Commit 7d79afd4 authored by omar92

Merge branch 'master_old'

parents a4049989 96b57a9a
name: "Windows Release cu118 dependencies 2"
on:
workflow_dispatch:
# push:
# branches:
# - master
jobs:
build_dependencies:
runs-on: windows-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10.9'
- shell: bash
run: |
python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers==0.0.19.dev516 --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir cu118_python_deps
tar cf cu118_python_deps.tar cu118_python_deps
- uses: actions/cache/save@v3
with:
path: cu118_python_deps.tar
key: ${{ runner.os }}-build-cu118
@@ -32,14 +32,28 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)

## Shortcuts

| Keybind | Explanation |
| - | - |
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Delete/Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| Shift + Drag | Move multiple selected nodes at the same time |
| Ctrl + D | Load default graph |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |

Ctrl can also be replaced with Cmd instead for MacOS users
# Installing
@@ -69,7 +83,7 @@ Put your VAE in: models/vae

At the time of writing, pytorch has issues with python versions higher than 3.10, so make sure your python/pip version is 3.10.

### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if they aren't already installed; this is the command to install the stable version:

```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2```
......
import math

import torch
from torch import nn, einsum
from ldm.modules.attention import CrossAttention
from inspect import isfunction
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * torch.nn.functional.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
class GatedCrossAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()
self.attn = CrossAttention(
query_dim=query_dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head)
self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
# this can be useful: we can externally change magnitude of tanh(alpha)
# for example, when it is set to 0, then the entire model is same as
# original one
self.scale = 1
def forward(self, x, objs):
x = x + self.scale * \
torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
x = x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
return x
class GatedSelfAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()
# we need a linear projection since we need cat visual feature and obj
# feature
self.linear = nn.Linear(context_dim, query_dim)
self.attn = CrossAttention(
query_dim=query_dim,
context_dim=query_dim,
heads=n_heads,
dim_head=d_head)
self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
# this can be useful: we can externally change magnitude of tanh(alpha)
# for example, when it is set to 0, then the entire model is same as
# original one
self.scale = 1
def forward(self, x, objs):
N_visual = x.shape[1]
objs = self.linear(objs)
x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
x = x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
return x
class GatedSelfAttentionDense2(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()
# we need a linear projection since we need cat visual feature and obj
# feature
self.linear = nn.Linear(context_dim, query_dim)
self.attn = CrossAttention(
query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
self.ff = FeedForward(query_dim, glu=True)
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
# this can be useful: we can externally change magnitude of tanh(alpha)
# for example, when it is set to 0, then the entire model is same as
# original one
self.scale = 1
def forward(self, x, objs):
B, N_visual, _ = x.shape
B, N_ground, _ = objs.shape
objs = self.linear(objs)
# sanity check
size_v = math.sqrt(N_visual)
size_g = math.sqrt(N_ground)
assert int(size_v) == size_v, "Visual tokens must be square rootable"
assert int(size_g) == size_g, "Grounding tokens must be square rootable"
size_v = int(size_v)
size_g = int(size_g)
# select grounding token and resize it to visual token size as residual
out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
:, N_visual:, :]
out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
out = torch.nn.functional.interpolate(
out, (size_v, size_v), mode='bicubic')
residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)
# add residual to visual feature
x = x + self.scale * torch.tanh(self.alpha_attn) * residual
x = x + self.scale * \
torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
return x
class FourierEmbedder():
def __init__(self, num_freqs=64, temperature=100):
self.num_freqs = num_freqs
self.temperature = temperature
self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
@torch.no_grad()
def __call__(self, x, cat_dim=-1):
"x: arbitrary shape of tensor. dim: cat dim"
out = []
for freq in self.freq_bands:
out.append(torch.sin(freq * x))
out.append(torch.cos(freq * x))
return torch.cat(out, cat_dim)
class PositionNet(nn.Module):
def __init__(self, in_dim, out_dim, fourier_freqs=8):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
self.linears = nn.Sequential(
nn.Linear(self.in_dim + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.null_positive_feature = torch.nn.Parameter(
torch.zeros([self.in_dim]))
self.null_position_feature = torch.nn.Parameter(
torch.zeros([self.position_dim]))
def forward(self, boxes, masks, positive_embeddings):
B, N, _ = boxes.shape
masks = masks.unsqueeze(-1)
# embedding position (it may includes padding as placeholder)
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
xyxy_null = self.null_position_feature.view(1, 1, -1)
# replace padding with learnable null embedding
positive_embeddings = positive_embeddings * \
masks + (1 - masks) * positive_null
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
objs = self.linears(
torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
assert objs.shape == torch.Size([B, N, self.out_dim])
return objs
class Gligen(nn.Module):
def __init__(self, modules, position_net, key_dim):
super().__init__()
self.module_list = nn.ModuleList(modules)
self.position_net = position_net
self.key_dim = key_dim
self.max_objs = 30
def _set_position(self, boxes, masks, positive_embeddings):
objs = self.position_net(boxes, masks, positive_embeddings)
def func(key, x):
module = self.module_list[key]
return module(x, objs)
return func
def set_position(self, latent_image_shape, position_params, device):
batch, c, h, w = latent_image_shape
masks = torch.zeros([self.max_objs], device="cpu")
boxes = []
positive_embeddings = []
for p in position_params:
x1 = (p[4]) / w
y1 = (p[3]) / h
x2 = (p[4] + p[2]) / w
y2 = (p[3] + p[1]) / h
masks[len(boxes)] = 1.0
boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
positive_embeddings += [p[0]]
append_boxes = []
append_conds = []
if len(boxes) < self.max_objs:
append_boxes = [torch.zeros(
[self.max_objs - len(boxes), 4], device="cpu")]
append_conds = [torch.zeros(
[self.max_objs - len(boxes), self.key_dim], device="cpu")]
box_out = torch.cat(
boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
masks = masks.unsqueeze(0).repeat(batch, 1)
conds = torch.cat(positive_embeddings +
append_conds).unsqueeze(0).repeat(batch, 1, 1)
return self._set_position(
box_out.to(device),
masks.to(device),
conds.to(device))
def set_empty(self, latent_image_shape, device):
batch, c, h, w = latent_image_shape
masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
box_out = torch.zeros([self.max_objs, 4],
device="cpu").repeat(batch, 1, 1)
conds = torch.zeros([self.max_objs, self.key_dim],
device="cpu").repeat(batch, 1, 1)
return self._set_position(
box_out.to(device),
masks.to(device),
conds.to(device))
def cleanup(self):
pass
def get_models(self):
return [self]
def load_gligen(sd):
sd_k = sd.keys()
output_list = []
key_dim = 768
for a in ["input_blocks", "middle_block", "output_blocks"]:
for b in range(20):
k_temp = filter(lambda k: "{}.{}.".format(a, b)
in k and ".fuser." in k, sd_k)
k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)
n_sd = {}
for k in k_temp:
n_sd[k[1]] = sd[k[0]]
if len(n_sd) > 0:
query_dim = n_sd["linear.weight"].shape[0]
key_dim = n_sd["linear.weight"].shape[1]
if key_dim == 768: # SD1.x
n_heads = 8
d_head = query_dim // n_heads
else:
d_head = 64
n_heads = query_dim // d_head
gated = GatedSelfAttentionDense(
query_dim, key_dim, n_heads, d_head)
gated.load_state_dict(n_sd, strict=False)
output_list.append(gated)
if "position_net.null_positive_feature" in sd_k:
in_dim = sd["position_net.null_positive_feature"].shape[0]
out_dim = sd["position_net.linears.4.weight"].shape[0]
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
w.position_net = PositionNet(in_dim, out_dim)
w.load_state_dict(sd, strict=False)
gligen = Gligen(output_list, w.position_net, key_dim)
return gligen
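A brief usage sketch of how this file is wired into sampling (illustrative only; the checkpoint path, latent shape and box values below are assumptions, and the pooled embedding is a stand-in): `load_gligen` builds the gated attention modules from a GLIGEN state dict, `set_position` bakes the grounding boxes into a patch function, and the sampler stores that function under `transformer_options["patches"]["middle_patch"]` so each `BasicTransformerBlock` can call it with its running block index.

```python
import torch

# Assumption: "gligen_sd14.pt" is a GLIGEN checkpoint for SD1.x (key_dim 768).
sd = torch.load("gligen_sd14.pt")          # state dict with the .fuser. and position_net weights
gligen_model = load_gligen(sd)

cond_pooled = torch.zeros([1, 768])        # stand-in for a pooled CLIP text embedding

# One grounding box in latent units, (embedding, h//8, w//8, y//8, x//8),
# matching what the GLIGENTextBoxApply node stores in the conditioning.
position_params = [(cond_pooled, 32, 32, 16, 16)]

# For a 512x512 image the latent shape is [batch, 4, 64, 64].
patch = gligen_model.set_position((1, 4, 64, 64), position_params, device="cpu")

# During sampling each BasicTransformerBlock calls the patch as patch(current_index, x),
# where x is [batch, tokens, query_dim]; the returned tensor replaces x for that block.
```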
@@ -9,7 +9,7 @@ from typing import Optional, Any
from ldm.modules.diffusionmodules.util import checkpoint
from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
from . import tomesd
@@ -510,6 +510,14 @@ class BasicTransformerBlock(nn.Module):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}):
current_index = None
if "current_index" in transformer_options:
current_index = transformer_options["current_index"]
if "patches" in transformer_options:
transformer_patches = transformer_options["patches"]
else:
transformer_patches = {}
n = self.norm1(x)
if "tomesd" in transformer_options:
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
@@ -518,11 +526,19 @@ class BasicTransformerBlock(nn.Module):
n = self.attn1(n, context=context if self.disable_self_attn else None)
x += n
if "middle_patch" in transformer_patches:
patch = transformer_patches["middle_patch"]
for p in patch:
x = p(current_index, x)
n = self.norm2(x)
n = self.attn2(n, context=context)
x += n
x = self.ff(self.norm3(x)) + x
if current_index is not None:
transformer_options["current_index"] += 1
return x
......
@@ -7,7 +7,7 @@ from einops import rearrange
from typing import Optional, Any
from ldm.modules.attention import MemoryEfficientCrossAttention
from comfy import model_management
if model_management.xformers_enabled_vae():
import xformers
......
@@ -782,6 +782,8 @@ class UNetModel(nn.Module):
:return: an [N x C x ...] Tensor of outputs.
"""
transformer_options["original_shape"] = list(x.shape)
transformer_options["current_index"] = 0
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
......
@@ -24,7 +24,7 @@ except ImportError:
from torch import Tensor
from typing import List
from comfy import model_management
def dynamic_slice(
x: Tensor,
......
@@ -176,7 +176,7 @@ def load_model_gpu(model):
model_accelerated = True
return current_loaded_model
def load_controlnet_gpu(control_models):
global current_gpu_controlnets
global vram_state
if vram_state == VRAMState.CPU:
@@ -186,6 +186,10 @@ def load_controlnet_gpu(models):
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
return
models = []
for m in control_models:
models += m.get_models()
for m in current_gpu_controlnets:
if m not in models:
m.cpu()
......
@@ -3,7 +3,7 @@ from .k_diffusion import external as k_diffusion_external
from .extra_samplers import uni_pc
import torch
import contextlib
from comfy import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
@@ -36,8 +36,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
strength = cond[1]['strength']
adm_cond = None
if 'adm_encoded' in cond[1]:
adm_cond = cond[1]['adm_encoded']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
mult = torch.ones_like(input_x) * strength
@@ -70,7 +70,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
control = None
if 'control' in cond[1]:
control = cond[1]['control']
patches = None
if 'gligen' in cond[1]:
gligen = cond[1]['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device)
else:
gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch]
return (input_x, mult, conditionning, area, control, patches)
def cond_equal_size(c1, c2):
if c1 is c2:
@@ -91,12 +105,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
def can_concat_cond(c1, c2):
if c1[0].shape != c2[0].shape:
return False
#control
if (c1[4] is None) != (c2[4] is None):
return False
if c1[4] is not None:
if c1[4] is not c2[4]:
return False
#patches
if (c1[5] is None) != (c2[5] is None):
return False
if (c1[5] is not None):
if c1[5] is not c2[5]:
return False
return cond_equal_size(c1[2], c2[2])
def cond_cat(c_list):
@@ -166,6 +189,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
@@ -175,6 +199,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
area += [p[3]]
cond_or_uncond += [o[1]]
control = p[4]
patches = p[5]
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
@@ -184,8 +209,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
if patches is not None:
transformer_options["patches"] = patches
c['transformer_options'] = transformer_options
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
del input_x
@@ -211,7 +242,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
if "sampler_cfg_function" in model_options:
return model_options["sampler_cfg_function"](cond, uncond, cond_scale)
else:
return uncond + (cond - uncond) * cond_scale
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
@@ -306,8 +340,7 @@ def create_cond_with_same_area_if_none(conds, c):
n = c[1].copy()
conds += [[smallest[0], n]]
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
cond_other = []
uncond_cnets = []
@@ -315,15 +348,15 @@ def apply_control_net_to_equal_area(conds, uncond):
for t in range(len(conds)):
x = conds[t]
if 'area' not in x[1]:
if name in x[1] and x[1][name] is not None:
cond_cnets.append(x[1][name])
else:
cond_other.append((x, t))
for t in range(len(uncond)):
x = uncond[t]
if 'area' not in x[1]:
if name in x[1] and x[1][name] is not None:
uncond_cnets.append(x[1][name])
else:
uncond_other.append((x, t))
@@ -333,15 +366,16 @@ def apply_control_net_to_equal_area(conds, uncond):
for x in range(len(cond_cnets)):
temp = uncond_other[x % len(uncond_other)]
o = temp[0]
if name in o[1] and o[1][name] is not None:
n = o[1].copy()
n[name] = uncond_fill_func(cond_cnets, x)
uncond += [[o[0], n]]
else:
n = o[1].copy()
n[name] = uncond_fill_func(cond_cnets, x)
uncond[temp[1]] = [o[0], n]
def encode_adm(noise_augmentor, conds, batch_size, device):
for t in range(len(conds)):
x = conds[t]
@@ -371,10 +405,11 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
else:
adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
x[1] = x[1].copy()
x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
return conds
class KSampler:
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
@@ -463,7 +498,8 @@ class KSampler:
for c in negative:
create_cond_with_same_area_if_none(positive, c)
apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if self.model.model.diffusion_model.dtype == torch.float16:
precision_scope = torch.autocast
......
@@ -4,7 +4,7 @@ import copy
import sd1_clip
import sd2_clip
from comfy import model_management
from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL
import yaml
@@ -13,6 +13,7 @@ from .t2i_adapter import adapter
from . import utils
from . import clip_vision
from . import gligen
def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
m, u = model.load_state_dict(sd, strict=False)
@@ -250,6 +251,9 @@ class ModelPatcher:
def set_model_tomesd(self, ratio):
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}
def set_model_sampler_cfg_function(self, sampler_cfg_function):
self.model_options["sampler_cfg_function"] = sampler_cfg_function
def model_dtype(self):
return self.model.diffusion_model.dtype
@@ -372,10 +376,12 @@ class CLIP:
def clip_layer(self, layer_idx):
self.layer_idx = layer_idx
def tokenize(self, text, return_word_ids=False):
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
def encode_from_tokens(self, tokens, return_pooled=False):
if self.layer_idx is not None:
self.cond_stage_model.clip_layer(self.layer_idx)
try:
self.patcher.patch_model()
cond = self.cond_stage_model.encode_token_weights(tokens)
@@ -383,8 +389,16 @@
except Exception as e:
self.patcher.unpatch_model()
raise e
if return_pooled:
eos_token_index = max(range(len(tokens[0])), key=tokens[0].__getitem__)
pooled = cond[:, eos_token_index]
return cond, pooled
return cond
def encode(self, text):
tokens = self.tokenize(text)
return self.encode_from_tokens(tokens)
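A small usage sketch of the new split API (illustrative; `clip` is assumed to be an existing `comfy.sd.CLIP` instance):

```python
tokens = clip.tokenize("a photo of a cat")
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)

# clip.encode(text) still works and is equivalent to encode_from_tokens(clip.tokenize(text)).
```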
class VAE:
def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
if config is None:
@@ -555,10 +569,10 @@ class ControlNet:
c.strength = self.strength
return c
def get_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
out.append(self.control_model)
return out
@@ -728,10 +742,10 @@ class T2IAdapter:
del self.cond_hint
self.cond_hint = None
def get_models(self):
out = []
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def load_t2i_adapter(t2i_data):
@@ -778,6 +792,13 @@ def load_clip(ckpt_path, embedding_directory=None):
clip.load_from_state_dict(clip_data)
return clip
def load_gligen(ckpt_path):
data = utils.load_torch_file(ckpt_path)
model = gligen.load_gligen(data)
if model_management.should_use_fp16():
model = model.half()
return model
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
with open(config_path, 'r') as stream:
config = yaml.safe_load(stream)
......
@@ -260,60 +260,97 @@ class SD1Tokenizer:
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.embedding_directory = embedding_directory
        self.max_word_length = 8
        self.embedding_identifier = "embedding:"

    def _try_get_embedding(self, embedding_name:str):
        '''
        Takes a potential embedding name and tries to retrieve it.
        Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
        '''
        embed = load_embed(embedding_name, self.embedding_directory)
        if embed is None:
            stripped = embedding_name.strip(',')
            if len(stripped) < len(embedding_name):
                embed = load_embed(stripped, self.embedding_directory)
                return (embed, embedding_name[len(stripped):])
        return (embed, "")

    def tokenize_with_weights(self, text:str, return_word_ids=False):
        '''
        Takes a prompt and converts it to a list of (token, weight, word id) elements.
        Tokens can both be integer tokens and pre computed CLIP tensors.
        Word id values are unique per word and embedding, where the id 0 is reserved for non word tokens.
        Returned list has the dimensions NxM where M is the input size of CLIP
        '''
        if self.pad_with_end:
            pad_token = self.end_token
        else:
            pad_token = 0

        text = escape_important(text)
        parsed_weights = token_weights(text, 1.0)

        #tokenize words
        tokens = []
        for weighted_segment, weight in parsed_weights:
            to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
            to_tokenize = [x for x in to_tokenize if x != ""]
            for word in to_tokenize:
                #if we find an embedding, deal with the embedding
                if word.startswith(self.embedding_identifier) and self.embedding_directory is not None:
                    embedding_name = word[len(self.embedding_identifier):].strip('\n')
                    embed, leftover = self._try_get_embedding(embedding_name)
                    if embed is None:
                        print(f"warning, embedding:{embedding_name} does not exist, ignoring")
                    else:
                        if len(embed.shape) == 1:
                            tokens.append([(embed, weight)])
                        else:
                            tokens.append([(embed[x], weight) for x in range(embed.shape[0])])
                    #if we accidentally have leftover text, continue parsing using leftover, else move on to next word
                    if leftover != "":
                        word = leftover
                    else:
                        continue
                #parse word
                tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])

        #reshape token array to CLIP input size
        batched_tokens = []
        batch = [(self.start_token, 1.0, 0)]
        batched_tokens.append(batch)
        for i, t_group in enumerate(tokens):
            #determine if we're going to try and keep the tokens in a single batch
            is_large = len(t_group) >= self.max_word_length

            while len(t_group) > 0:
                if len(t_group) + len(batch) > self.max_length - 1:
                    remaining_length = self.max_length - len(batch) - 1
                    #break word in two and add end token
                    if is_large:
                        batch.extend([(t,w,i+1) for t,w in t_group[:remaining_length]])
                        batch.append((self.end_token, 1.0, 0))
                        t_group = t_group[remaining_length:]
                    #add end token and pad
                    else:
                        batch.append((self.end_token, 1.0, 0))
                        batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
                    #start new batch
                    batch = [(self.start_token, 1.0, 0)]
                    batched_tokens.append(batch)
                else:
                    batch.extend([(t,w,i+1) for t,w in t_group])
                    t_group = []

        #fill last batch
        batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))

        if not return_word_ids:
            batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]

        return batched_tokens

    def untokenize(self, token_weight_pair):
        return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
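To make the returned structure concrete, a small sketch (illustrative; `tokenizer` is assumed to be an `SD1Tokenizer` instance):

```python
batches = tokenizer.tokenize_with_weights("a (red:1.2) ball", return_word_ids=True)

first = batches[0]
assert len(first) == tokenizer.max_length   # every batch is padded to the CLIP input size (77)
token, weight, word_id = first[1]           # first real token; index 0 is the start token
# word_id is 0 for start/end/pad tokens and otherwise identifies which prompt word the token came from.
```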
from comfy import sd1_clip
import torch
import os
......
import os
from comfy_extras.chainner_models import model_loading
from comfy import model_management
import torch
import comfy.utils
import folder_paths
......
@@ -18,6 +18,7 @@ a111:
#other_ui:
# base_path: path/to/ui
# checkpoints: models/checkpoints
# gligen: models/gligen
# custom_nodes: path/custom_nodes
@@ -12,8 +12,8 @@ except:
folder_names_and_paths = {}
base_path = os.path.dirname(os.path.realpath(__file__))
models_dir = os.path.join(base_path, "models")
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions)
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
@@ -26,8 +26,13 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")]
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
......
@@ -81,6 +81,14 @@ if __name__ == "__main__":
server = server.PromptServer(loop)
q = execution.PromptQueue(server)
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path)
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path)
init_custom_nodes()
server.add_routes()
hijack_progress(server)
@@ -91,13 +99,6 @@ if __name__ == "__main__":
dont_print = args.dont_print_server
if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
......
@@ -21,16 +21,16 @@ import comfy.utils
import comfy.clip_vision
import comfy.model_management
import importlib
import folder_paths
def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted()
def interrupt_processing(value=True):
comfy.model_management.interrupt_current_processing(value)
MAX_RESOLUTION=8192
@@ -241,7 +241,7 @@ class DiffusersLoader:
model_path = os.path.join(search_path, model_path)
break
return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class unCLIPCheckpointLoader:
@@ -490,6 +490,51 @@ class unCLIPConditioning:
c.append(n)
return (c, )
class GLIGENLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}
RETURN_TYPES = ("GLIGEN",)
FUNCTION = "load_gligen"
CATEGORY = "_for_testing/gligen"
def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
gligen = comfy.sd.load_gligen(gligen_path)
return (gligen,)
class GLIGENTextBoxApply:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning_to": ("CONDITIONING", ),
"clip": ("CLIP", ),
"gligen_textbox_model": ("GLIGEN", ),
"text": ("STRING", {"multiline": True}),
"width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "_for_testing/gligen"
def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
c = []
cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
for t in conditioning_to:
n = [t[0], t[1].copy()]
position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
prev = []
if "gligen" in n[1]:
prev = n[1]['gligen'][2]
n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
c.append(n)
return (c, )
class EmptyLatentImage:
def __init__(self, device="cpu"):
@@ -510,6 +555,24 @@ class EmptyLatentImage:
return ({"samples":latent}, )
class LatentFromBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "rotate"
CATEGORY = "latent"
def rotate(self, samples, batch_index):
s = samples.copy()
s_in = samples["samples"]
batch_index = min(s_in.shape[0] - 1, batch_index)
s["samples"] = s_in[batch_index:batch_index + 1].clone()
s["batch_index"] = batch_index
return (s,)
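A sketch of why `LatentFromBatch` stores `batch_index` (illustrative values only): when `common_ksampler` later prepares noise, it draws and discards `batch_index` single-latent noise tensors from the seeded generator before drawing the one it uses, so a latent picked out of a batch keeps noise that depends on its original position rather than always getting the index-0 noise.

```python
# Assumed values: seed 1234, a 512x512 latent, and batch_index == 2 from LatentFromBatch.
import torch

generator = torch.manual_seed(1234)
for _ in range(2):                                        # advance the generator past indices 0 and 1
    torch.randn([1, 4, 64, 64], generator=generator, device="cpu")
noise = torch.randn([1, 4, 64, 64], generator=generator, device="cpu")  # noise actually used
```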
class LatentUpscale:
upscale_methods = ["nearest-exact", "bilinear", "area"]
@@ -680,12 +743,19 @@ class SetLatentNoiseMask:
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
latent_image = latent["samples"]
noise_mask = None
device = comfy.model_management.get_torch_device()
if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
batch_index = 0
if "batch_index" in latent:
batch_index = latent["batch_index"]
generator = torch.manual_seed(seed)
for i in range(batch_index):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
if "noise_mask" in latent: if "noise_mask" in latent:
noise_mask = latent['noise_mask'] noise_mask = latent['noise_mask']
...@@ -696,7 +766,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, ...@@ -696,7 +766,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
noise_mask = noise_mask.to(device) noise_mask = noise_mask.to(device)
real_model = None real_model = None
model_management.load_model_gpu(model) comfy.model_management.load_model_gpu(model)
real_model = model.model real_model = model.model
noise = noise.to(device) noise = noise.to(device)
...@@ -706,27 +776,30 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, ...@@ -706,27 +776,30 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
negative_copy = []
control_nets = []
def get_models(cond):
models = []
for c in cond:
if 'control' in c[1]:
models += [c[1]['control']]
if 'gligen' in c[1]:
models += [c[1]['gligen'][1]]
return models
for p in positive:
t = p[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
positive_copy += [[t] + p[1:]]
for n in negative:
t = n[0]
if t.shape[0] < noise.shape[0]:
t = torch.cat([t] * noise.shape[0])
t = t.to(device)
negative_copy += [[t] + n[1:]]
models = get_models(positive) + get_models(negative)
comfy.model_management.load_controlnet_gpu(models)
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
@@ -736,8 +809,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
samples = samples.cpu()
for m in models:
m.cleanup()
out = latent.copy()
out["samples"] = samples out["samples"] = samples
...@@ -1073,6 +1146,7 @@ NODE_CLASS_MAPPINGS = { ...@@ -1073,6 +1146,7 @@ NODE_CLASS_MAPPINGS = {
"VAELoader": VAELoader, "VAELoader": VAELoader,
"EmptyLatentImage": EmptyLatentImage, "EmptyLatentImage": EmptyLatentImage,
"LatentUpscale": LatentUpscale, "LatentUpscale": LatentUpscale,
"LatentFromBatch": LatentFromBatch,
"SaveImage": SaveImage, "SaveImage": SaveImage,
"PreviewImage": PreviewImage, "PreviewImage": PreviewImage,
"LoadImage": LoadImage, "LoadImage": LoadImage,
...@@ -1102,6 +1176,9 @@ NODE_CLASS_MAPPINGS = { ...@@ -1102,6 +1176,9 @@ NODE_CLASS_MAPPINGS = {
"VAEEncodeTiled": VAEEncodeTiled, "VAEEncodeTiled": VAEEncodeTiled,
"TomePatchModel": TomePatchModel, "TomePatchModel": TomePatchModel,
"unCLIPCheckpointLoader": unCLIPCheckpointLoader, "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
"GLIGENLoader": GLIGENLoader,
"GLIGENTextBoxApply": GLIGENTextBoxApply,
"CheckpointLoader": CheckpointLoader, "CheckpointLoader": CheckpointLoader,
"DiffusersLoader": DiffusersLoader, "DiffusersLoader": DiffusersLoader,
} }
@@ -1178,15 +1255,16 @@ def load_custom_node(module_path):
print(f"Cannot import {module_path} module for custom nodes:", e)
def load_custom_nodes():
node_paths = folder_paths.get_folder_paths("custom_nodes")
for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path)
if "__pycache__" in possible_modules:
possible_modules.remove("__pycache__")
for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module)
if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
load_custom_node(module_path)
def init_custom_nodes():
load_custom_nodes()
......
@@ -138,6 +138,11 @@
"# Controlnet Preprocessor nodes by Fannovel16\n",
"#!cd custom_nodes && git clone https://github.com/Fannovel16/comfy_controlnet_preprocessors; cd comfy_controlnet_preprocessors && python install.py\n",
"\n",
"\n",
"# GLIGEN\n",
"#!wget -c https://huggingface.co/comfyanonymous/GLIGEN_pruned_safetensors/resolve/main/gligen_sd14_textbox_pruned_fp16.safetensors -P ./models/gligen/\n",
"\n",
"\n",
"# ESRGAN upscale model\n", "# ESRGAN upscale model\n",
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n",
"#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n",
......
import { app } from "/scripts/app.js";
// Allows you to edit the attention weight by holding ctrl (or cmd) and using the up/down arrow keys
app.registerExtension({
name: "Comfy.EditAttention",
init() {
const editAttentionDelta = app.ui.settings.addSetting({
id: "Comfy.EditAttention.Delta",
name: "Ctrl+up/down precision",
type: "slider",
attrs: {
min: 0.01,
max: 0.5,
step: 0.01,
},
defaultValue: 0.05,
});
function incrementWeight(weight, delta) {
const floatWeight = parseFloat(weight);
if (isNaN(floatWeight)) return weight;
const newWeight = floatWeight + delta;
if (newWeight < 0) return "0";
return String(Number(newWeight.toFixed(10)));
}
function findNearestEnclosure(text, cursorPos) {
let start = cursorPos, end = cursorPos;
let openCount = 0, closeCount = 0;
// Find opening parenthesis before cursor
while (start >= 0) {
start--;
if (text[start] === "(" && openCount === closeCount) break;
if (text[start] === "(") openCount++;
if (text[start] === ")") closeCount++;
}
if (start < 0) return false;
openCount = 0;
closeCount = 0;
// Find closing parenthesis after cursor
while (end < text.length) {
if (text[end] === ")" && openCount === closeCount) break;
if (text[end] === "(") openCount++;
if (text[end] === ")") closeCount++;
end++;
}
if (end === text.length) return false;
return { start: start + 1, end: end };
}
function addWeightToParentheses(text) {
const parenRegex = /^\((.*)\)$/;
const parenMatch = text.match(parenRegex);
const floatRegex = /:([+-]?(\d*\.)?\d+([eE][+-]?\d+)?)/;
const floatMatch = text.match(floatRegex);
if (parenMatch && !floatMatch) {
return `(${parenMatch[1]}:1.0)`;
} else {
return text;
}
};
function editAttention(event) {
const inputField = event.composedPath()[0];
const delta = parseFloat(editAttentionDelta.value);
if (inputField.tagName !== "TEXTAREA") return;
if (!(event.key === "ArrowUp" || event.key === "ArrowDown")) return;
if (!event.ctrlKey && !event.metaKey) return;
event.preventDefault();
let start = inputField.selectionStart;
let end = inputField.selectionEnd;
let selectedText = inputField.value.substring(start, end);
// If there is no selection, attempt to find the nearest enclosure, or select the current word
if (!selectedText) {
const nearestEnclosure = findNearestEnclosure(inputField.value, start);
if (nearestEnclosure) {
start = nearestEnclosure.start;
end = nearestEnclosure.end;
selectedText = inputField.value.substring(start, end);
} else {
// Select the current word, find the start and end of the word (first space before and after)
const wordStart = inputField.value.substring(0, start).lastIndexOf(" ") + 1;
const wordEnd = inputField.value.substring(end).indexOf(" ");
// If there is no space after the word, select to the end of the string
if (wordEnd === -1) {
end = inputField.value.length;
} else {
end += wordEnd;
}
start = wordStart;
// Remove all punctuation at the end and beginning of the word
while (inputField.value[start].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) {
start++;
}
while (inputField.value[end - 1].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) {
end--;
}
selectedText = inputField.value.substring(start, end);
if (!selectedText) return;
}
}
// If the selection ends with a space, remove it
if (selectedText[selectedText.length - 1] === " ") {
selectedText = selectedText.substring(0, selectedText.length - 1);
end -= 1;
}
// If there are parentheses left and right of the selection, select them
if (inputField.value[start - 1] === "(" && inputField.value[end] === ")") {
start -= 1;
end += 1;
selectedText = inputField.value.substring(start, end);
}
// If the selection is not enclosed in parentheses, add them
if (selectedText[0] !== "(" || selectedText[selectedText.length - 1] !== ")") {
selectedText = `(${selectedText})`;
}
// If the selection does not have a weight, add a weight of 1.0
selectedText = addWeightToParentheses(selectedText);
// Increment the weight
const weightDelta = event.key === "ArrowUp" ? delta : -delta;
const updatedText = selectedText.replace(/(.*:)(\d+(\.\d+)?)(.*)/, (match, prefix, weight, _, suffix) => {
return prefix + incrementWeight(weight, weightDelta) + suffix;
});
inputField.setRangeText(updatedText, start, end, "select");
}
window.addEventListener("keydown", editAttention);
},
});