Commit 3696d169 authored by comfyanonymous's avatar comfyanonymous

Add support for GLIGEN textbox model.

parent 472b1cc0
import math  # needed by GatedSelfAttentionDense2 below

import torch
from torch import nn, einsum
from ldm.modules.attention import CrossAttention
from inspect import isfunction


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * torch.nn.functional.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


class GatedCrossAttentionDense(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of
        # tanh(alpha); for example, when it is set to 0, the entire model is
        # the same as the original one
        self.scale = 1

    def forward(self, x, objs):
        x = x + self.scale * \
            torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
        return x


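# Gating note (illustrative, not part of the checkpoint code): both alphas
# start at 0 and tanh(0) == 0, so a freshly constructed gated block is an
# exact identity mapping; training (or self.scale) opens the gate.
#
#   blk = GatedCrossAttentionDense(query_dim=320, context_dim=768,
#                                  n_heads=8, d_head=40)
#   x = torch.randn(1, 64, 320)
#   objs = torch.randn(1, 30, 768)
#   assert torch.allclose(blk(x, objs), x)  # gates are closed at init
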
class GatedSelfAttentionDense(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # we need a linear projection since we concatenate the visual
        # feature and the object feature
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=query_dim,
            heads=n_heads,
            dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of
        # tanh(alpha); for example, when it is set to 0, the entire model is
        # the same as the original one
        self.scale = 1

    def forward(self, x, objs):
        N_visual = x.shape[1]
        objs = self.linear(objs)

        x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
            self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x


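# Shape sketch for the fuser block GLIGEN actually loads (illustrative
# numbers): with a 64x64 latent there are 4096 visual tokens; objs is first
# projected to query_dim, self-attention runs over the concatenated
# 4096 + 30 = 4126 tokens, and only the first 4096 (the visual part) are
# kept as the residual:
#
#   blk = GatedSelfAttentionDense(query_dim=320, context_dim=768,
#                                 n_heads=8, d_head=40)
#   x = torch.randn(1, 4096, 320)    # visual tokens
#   objs = torch.randn(1, 30, 768)   # grounding tokens from PositionNet
#   out = blk(x, objs)               # -> (1, 4096, 320)
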
class GatedSelfAttentionDense2(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # we need a linear projection since we concatenate the visual
        # feature and the object feature
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of
        # tanh(alpha); for example, when it is set to 0, the entire model is
        # the same as the original one
        self.scale = 1

    def forward(self, x, objs):
        B, N_visual, _ = x.shape
        B, N_ground, _ = objs.shape

        objs = self.linear(objs)

        # sanity check: both token sets must form square grids
        size_v = math.sqrt(N_visual)
        size_g = math.sqrt(N_ground)
        assert int(size_v) == size_v, "Visual token count must be a perfect square"
        assert int(size_g) == size_g, "Grounding token count must be a perfect square"
        size_v = int(size_v)
        size_g = int(size_g)

        # select the grounding tokens and resize them to the visual token
        # size to form the residual
        out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
            :, N_visual:, :]
        out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
        out = torch.nn.functional.interpolate(
            out, (size_v, size_v), mode='bicubic')
        residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)

        # add the residual to the visual feature
        x = x + self.scale * torch.tanh(self.alpha_attn) * residual
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x


class FourierEmbedder:
    def __init__(self, num_freqs=64, temperature=100):
        self.num_freqs = num_freqs
        self.temperature = temperature
        self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        "x: tensor of arbitrary shape. cat_dim: dimension to concatenate on"
        out = []
        for freq in self.freq_bands:
            out.append(torch.sin(freq * x))
            out.append(torch.cos(freq * x))
        return torch.cat(out, cat_dim)


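# A quick dimension check (illustrative, not part of the checkpoint code):
# with num_freqs=8 each scalar coordinate expands into 8 (sin, cos) pairs,
# so a B*N*4 tensor of boxes becomes B*N*64:
#
#   emb = FourierEmbedder(num_freqs=8)
#   out = emb(torch.rand(2, 30, 4))
#   assert out.shape == (2, 30, 64)
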
class PositionNet(nn.Module):
    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        self.position_dim = fourier_freqs * 2 * 4  # 2 is sin&cos, 4 is xyxy

        self.linears = nn.Sequential(
            nn.Linear(self.in_dim + self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

        self.null_positive_feature = torch.nn.Parameter(
            torch.zeros([self.in_dim]))
        self.null_position_feature = torch.nn.Parameter(
            torch.zeros([self.position_dim]))

    def forward(self, boxes, masks, positive_embeddings):
        B, N, _ = boxes.shape
        masks = masks.unsqueeze(-1)

        # embed positions (this may include padding as placeholder)
        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 --> B*N*C

        # learnable null embedding
        positive_null = self.null_positive_feature.view(1, 1, -1)
        xyxy_null = self.null_position_feature.view(1, 1, -1)

        # replace padding with the learnable null embedding
        positive_embeddings = positive_embeddings * \
            masks + (1 - masks) * positive_null
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        objs = self.linears(
            torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
        assert objs.shape == torch.Size([B, N, self.out_dim])
        return objs


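# Shape walk-through (illustrative values): for the SD1.x textbox model,
# in_dim=768 (a pooled CLIP embedding) and fourier_freqs=8 gives
# position_dim = 8 * 2 * 4 = 64, so each object is projected from a
# 768 + 64 = 832-dim vector down to out_dim:
#
#   net = PositionNet(in_dim=768, out_dim=768)
#   boxes = torch.rand(1, 30, 4)                  # normalized xyxy
#   masks = torch.zeros(1, 30); masks[:, 0] = 1.  # one real object, 29 padded
#   embs = torch.randn(1, 30, 768)
#   objs = net(boxes, masks, embs)                # -> (1, 30, 768)
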
class Gligen(nn.Module):
    def __init__(self, modules, position_net, key_dim):
        super().__init__()
        self.module_list = nn.ModuleList(modules)
        self.position_net = position_net
        self.key_dim = key_dim
        self.max_objs = 30

    def _set_position(self, boxes, masks, positive_embeddings):
        objs = self.position_net(boxes, masks, positive_embeddings)

        def func(key, x):
            module = self.module_list[key]
            return module(x, objs)
        return func

    def set_position(self, latent_image_shape, position_params, device):
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu")
        boxes = []
        positive_embeddings = []
        for p in position_params:
            # p is (embedding, height, width, y, x) in latent coordinates;
            # convert to normalized xyxy boxes
            x1 = (p[4]) / w
            y1 = (p[3]) / h
            x2 = (p[4] + p[2]) / w
            y2 = (p[3] + p[1]) / h
            masks[len(boxes)] = 1.0
            boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
            positive_embeddings += [p[0]]

        # pad up to max_objs with zero boxes and zero conds
        append_boxes = []
        append_conds = []
        if len(boxes) < self.max_objs:
            append_boxes = [torch.zeros(
                [self.max_objs - len(boxes), 4], device="cpu")]
            append_conds = [torch.zeros(
                [self.max_objs - len(boxes), self.key_dim], device="cpu")]

        box_out = torch.cat(
            boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
        masks = masks.unsqueeze(0).repeat(batch, 1)
        conds = torch.cat(positive_embeddings +
                          append_conds).unsqueeze(0).repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def set_empty(self, latent_image_shape, device):
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
        box_out = torch.zeros([self.max_objs, 4],
                              device="cpu").repeat(batch, 1, 1)
        conds = torch.zeros([self.max_objs, self.key_dim],
                            device="cpu").repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def cleanup(self):
        pass

    def get_models(self):
        return [self]


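# Usage sketch (hypothetical values): set_position turns latent-space
# (embedding, h, w, y, x) tuples into normalized xyxy boxes padded to
# max_objs, runs them through PositionNet, and returns a closure that
# applies fuser number `key` to the hidden states:
#
#   patch = gligen_model.set_position(
#       (1, 4, 64, 64),                  # latent shape of a 512x512 image
#       [(pooled_emb, 16, 16, 0, 0)],    # one 128x128px box at the top-left
#       torch.device("cpu"))
#   x = patch(0, x)                      # apply the first fuser block to x
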
def load_gligen(sd):
    sd_k = sd.keys()
    output_list = []
    key_dim = 768
    for a in ["input_blocks", "middle_block", "output_blocks"]:
        for b in range(20):
            k_temp = filter(lambda k: "{}.{}.".format(a, b)
                            in k and ".fuser." in k, sd_k)
            k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)
            n_sd = {}
            for k in k_temp:
                n_sd[k[1]] = sd[k[0]]
            if len(n_sd) > 0:
                query_dim = n_sd["linear.weight"].shape[0]
                key_dim = n_sd["linear.weight"].shape[1]

                if key_dim == 768:  # SD1.x
                    n_heads = 8
                    d_head = query_dim // n_heads
                else:
                    d_head = 64
                    n_heads = query_dim // d_head

                gated = GatedSelfAttentionDense(
                    query_dim, key_dim, n_heads, d_head)
                gated.load_state_dict(n_sd, strict=False)
                output_list.append(gated)

    if "position_net.null_positive_feature" in sd_k:
        in_dim = sd["position_net.null_positive_feature"].shape[0]
        out_dim = sd["position_net.linears.4.weight"].shape[0]

        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.position_net = PositionNet(in_dim, out_dim)
        w.load_state_dict(sd, strict=False)

    gligen = Gligen(output_list, w.position_net, key_dim)
    return gligen
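
The closure returned by `_set_position` is what later rides through `transformer_options["patches"]["middle_patch"]` in the attention change below. A minimal round trip, as a sketch assuming a GLIGEN checkpoint state dict and this module's import layout:

    import torch
    import gligen  # this new module

    sd = torch.load("gligen_checkpoint.pth")        # hypothetical file name
    model = gligen.load_gligen(sd)
    patch = model.set_empty((1, 4, 64, 64), torch.device("cpu"))
    x = torch.randn(1, 4096, 320)                   # hidden states of one block
    x = patch(0, x)                                 # fuser 0 with all-null grounding tokens
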
@@ -510,6 +510,14 @@ class BasicTransformerBlock(nn.Module):
         return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

     def _forward(self, x, context=None, transformer_options={}):
+        current_index = None
+        if "current_index" in transformer_options:
+            current_index = transformer_options["current_index"]
+
+        if "patches" in transformer_options:
+            transformer_patches = transformer_options["patches"]
+        else:
+            transformer_patches = {}
         n = self.norm1(x)
         if "tomesd" in transformer_options:
             m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
@@ -518,11 +526,19 @@ class BasicTransformerBlock(nn.Module):
         n = self.attn1(n, context=context if self.disable_self_attn else None)
         x += n

+        if "middle_patch" in transformer_patches:
+            patch = transformer_patches["middle_patch"]
+            for p in patch:
+                x = p(current_index, x)
+
         n = self.norm2(x)
         n = self.attn2(n, context=context)
         x += n
         x = self.ff(self.norm3(x)) + x
+
+        if current_index is not None:
+            transformer_options["current_index"] += 1
+
         return x
...
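
The `middle_patch` hook added above calls each patch as `p(current_index, x)` between the self-attention and cross-attention of every `BasicTransformerBlock`, with `current_index` counting blocks across one UNet pass. A minimal sketch of a conforming patch (illustrative, not part of this commit):

    def my_patch(current_index, x):
        # a no-op patch that just reports which block it runs in
        print("middle_patch at block", current_index, "tokens:", x.shape[1])
        return x

    transformer_options = {"patches": {"middle_patch": [my_patch]}}
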
@@ -782,6 +782,8 @@ class UNetModel(nn.Module):
         :return: an [N x C x ...] Tensor of outputs.
         """
         transformer_options["original_shape"] = list(x.shape)
+        transformer_options["current_index"] = 0
+
         assert (y is not None) == (
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
...
@@ -176,7 +176,7 @@ def load_model_gpu(model):
         model_accelerated = True
     return current_loaded_model

-def load_controlnet_gpu(models):
+def load_controlnet_gpu(control_models):
     global current_gpu_controlnets
     global vram_state
     if vram_state == VRAMState.CPU:
@@ -186,6 +186,10 @@ def load_controlnet_gpu(models):
         #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
         return

+    models = []
+    for m in control_models:
+        models += m.get_models()
+
     for m in current_gpu_controlnets:
         if m not in models:
             m.cpu()
...
@@ -70,7 +70,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
         control = None
         if 'control' in cond[1]:
             control = cond[1]['control']
-        return (input_x, mult, conditionning, area, control)
+
+        patches = None
+        if 'gligen' in cond[1]:
+            gligen = cond[1]['gligen']
+            patches = {}
+            gligen_type = gligen[0]
+            gligen_model = gligen[1]
+            if gligen_type == "position":
+                gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device)
+            else:
+                gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device)
+
+            patches['middle_patch'] = [gligen_patch]
+
+        return (input_x, mult, conditionning, area, control, patches)

     def cond_equal_size(c1, c2):
         if c1 is c2:
@@ -91,12 +105,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
     def can_concat_cond(c1, c2):
         if c1[0].shape != c2[0].shape:
             return False
+
+        #control
         if (c1[4] is None) != (c2[4] is None):
             return False
         if c1[4] is not None:
             if c1[4] is not c2[4]:
                 return False
+
+        #patches
+        if (c1[5] is None) != (c2[5] is None):
+            return False
+        if (c1[5] is not None):
+            if c1[5] is not c2[5]:
+                return False
+
         return cond_equal_size(c1[2], c2[2])

     def cond_cat(c_list):
@@ -166,6 +189,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             cond_or_uncond = []
             area = []
             control = None
+            patches = None
             for x in to_batch:
                 o = to_run.pop(x)
                 p = o[0]
@@ -175,6 +199,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
                 area += [p[3]]
                 cond_or_uncond += [o[1]]
                 control = p[4]
+                patches = p[5]

             batch_chunks = len(cond_or_uncond)
             input_x = torch.cat(input_x)
@@ -184,8 +209,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             if control is not None:
                 c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond))

+            transformer_options = {}
             if 'transformer_options' in model_options:
-                c['transformer_options'] = model_options['transformer_options']
+                transformer_options = model_options['transformer_options'].copy()
+
+            if patches is not None:
+                transformer_options["patches"] = patches
+
+            c['transformer_options'] = transformer_options
+
             output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
             del input_x
@@ -309,8 +340,7 @@ def create_cond_with_same_area_if_none(conds, c):
         n = c[1].copy()
         conds += [[smallest[0], n]]

-def apply_control_net_to_equal_area(conds, uncond):
+def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
     cond_cnets = []
     cond_other = []
     uncond_cnets = []
@@ -318,15 +348,15 @@ def apply_control_net_to_equal_area(conds, uncond):
     for t in range(len(conds)):
         x = conds[t]
         if 'area' not in x[1]:
-            if 'control' in x[1] and x[1]['control'] is not None:
-                cond_cnets.append(x[1]['control'])
+            if name in x[1] and x[1][name] is not None:
+                cond_cnets.append(x[1][name])
             else:
                 cond_other.append((x, t))
     for t in range(len(uncond)):
         x = uncond[t]
         if 'area' not in x[1]:
-            if 'control' in x[1] and x[1]['control'] is not None:
-                uncond_cnets.append(x[1]['control'])
+            if name in x[1] and x[1][name] is not None:
+                uncond_cnets.append(x[1][name])
             else:
                 uncond_other.append((x, t))
@@ -336,15 +366,16 @@ def apply_control_net_to_equal_area(conds, uncond):
     for x in range(len(cond_cnets)):
         temp = uncond_other[x % len(uncond_other)]
         o = temp[0]
-        if 'control' in o[1] and o[1]['control'] is not None:
+        if name in o[1] and o[1][name] is not None:
             n = o[1].copy()
-            n['control'] = cond_cnets[x]
+            n[name] = uncond_fill_func(cond_cnets, x)
             uncond += [[o[0], n]]
         else:
             n = o[1].copy()
-            n['control'] = cond_cnets[x]
+            n[name] = uncond_fill_func(cond_cnets, x)
             uncond[temp[1]] = [o[0], n]
+
 def encode_adm(noise_augmentor, conds, batch_size, device):
     for t in range(len(conds)):
         x = conds[t]
@@ -378,6 +409,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
     return conds

+
 class KSampler:
     SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
     SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
@@ -466,7 +498,8 @@ class KSampler:
         for c in negative:
             create_cond_with_same_area_if_none(positive, c)

-        apply_control_net_to_equal_area(positive, negative)
+        apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
+        apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])

         if self.model.model.diffusion_model.dtype == torch.float16:
             precision_scope = torch.autocast
...
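
With this change a conditioning entry can carry a `gligen` value of the form `(type, model, params)`, which the batching code above turns into the patch dict consumed by the attention hook. A sketch of the expected shape (names assumed):

    cond = [[cond_tensor, {
        'gligen': ("position", gligen_model, [(pooled_emb, 16, 16, 0, 0)]),
    }]]
    # -> patches = {'middle_patch': [patch_fn]} per batch, delivered to the
    #    UNet via transformer_options
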
@@ -13,6 +13,7 @@ from .t2i_adapter import adapter
 from . import utils
 from . import clip_vision
+from . import gligen

 def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
     m, u = model.load_state_dict(sd, strict=False)
@@ -378,7 +379,7 @@ class CLIP:
     def tokenize(self, text, return_word_ids=False):
         return self.tokenizer.tokenize_with_weights(text, return_word_ids)

-    def encode_from_tokens(self, tokens):
+    def encode_from_tokens(self, tokens, return_pooled=False):
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
         try:
@@ -388,6 +389,10 @@ class CLIP:
         except Exception as e:
             self.patcher.unpatch_model()
             raise e
+        if return_pooled:
+            eos_token_index = max(range(len(tokens[0])), key=tokens[0].__getitem__)
+            pooled = cond[:, eos_token_index]
+            return cond, pooled
         return cond

     def encode(self, text):
@@ -564,10 +569,10 @@ class ControlNet:
         c.strength = self.strength
         return c

-    def get_control_models(self):
+    def get_models(self):
         out = []
         if self.previous_controlnet is not None:
-            out += self.previous_controlnet.get_control_models()
+            out += self.previous_controlnet.get_models()
         out.append(self.control_model)
         return out
@@ -737,10 +742,10 @@ class T2IAdapter:
         del self.cond_hint
         self.cond_hint = None

-    def get_control_models(self):
+    def get_models(self):
         out = []
         if self.previous_controlnet is not None:
-            out += self.previous_controlnet.get_control_models()
+            out += self.previous_controlnet.get_models()
         return out

 def load_t2i_adapter(t2i_data):
@@ -787,6 +792,13 @@ def load_clip(ckpt_path, embedding_directory=None):
     clip.load_from_state_dict(clip_data)
     return clip

+def load_gligen(ckpt_path):
+    data = utils.load_torch_file(ckpt_path)
+    model = gligen.load_gligen(data)
+    if model_management.should_use_fp16():
+        model = model.half()
+    return model
+
 def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
     with open(config_path, 'r') as stream:
         config = yaml.safe_load(stream)
...
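
The pooled output added to `encode_from_tokens` is the hidden state at the end-of-text token, located as the position of the largest token id (CLIP's EOS/pad id, 49407, is the highest in its vocabulary). A simplified sketch with plain ids (the real token lists also carry per-token weights):

    tokens = [[49406, 320, 1125, 49407, 49407]]  # BOS, "a", "photo", EOS, pad
    eos_token_index = max(range(len(tokens[0])), key=tokens[0].__getitem__)
    # -> 3: max() returns the first occurrence of the maximum id, i.e. the
    #    EOS token rather than trailing padding
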
@@ -26,6 +26,8 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")]
 folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
 folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
+
+folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
 folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
 folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
...
@@ -490,6 +490,51 @@ class unCLIPConditioning:
         c.append(n)
         return (c, )

+class GLIGENLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}}
+
+    RETURN_TYPES = ("GLIGEN",)
+    FUNCTION = "load_gligen"
+
+    CATEGORY = "_for_testing/gligen"
+
+    def load_gligen(self, gligen_name):
+        gligen_path = folder_paths.get_full_path("gligen", gligen_name)
+        gligen = comfy.sd.load_gligen(gligen_path)
+        return (gligen,)
+
+class GLIGENTextBoxApply:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"conditioning_to": ("CONDITIONING", ),
+                             "clip": ("CLIP", ),
+                             "gligen_textbox_model": ("GLIGEN", ),
+                             "text": ("STRING", {"multiline": True}),
+                             "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
+                             "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
+                             "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                             "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
+                             }}
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "append"
+
+    CATEGORY = "_for_testing/gligen"
+
+    def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
+        c = []
+        cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
+        for t in conditioning_to:
+            n = [t[0], t[1].copy()]
+            position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
+            prev = []
+            if "gligen" in n[1]:
+                prev = n[1]['gligen'][2]
+
+            n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
+            c.append(n)
+        return (c, )
+
 class EmptyLatentImage:
     def __init__(self, device="cpu"):
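
GLIGENTextBoxApply stores its pixel inputs divided by 8 because GLIGEN boxes live on the latent grid; `Gligen.set_position` then normalizes them against the latent width and height. A worked example with assumed values:

    width, height, x, y = 256, 192, 64, 128        # node inputs in pixels
    p = (cond_pooled, height // 8, width // 8, y // 8, x // 8)  # (emb, 24, 32, 16, 8)
    w = h = 64                                     # 512x512 image -> 64x64 latent
    x1, y1 = p[4] / w, p[3] / h                    # (0.125, 0.25)
    x2, y2 = (p[4] + p[2]) / w, (p[3] + p[1]) / h  # (0.625, 0.625)
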
@@ -731,27 +776,30 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
     negative_copy = []

     control_nets = []
+    def get_models(cond):
+        models = []
+        for c in cond:
+            if 'control' in c[1]:
+                models += [c[1]['control']]
+            if 'gligen' in c[1]:
+                models += [c[1]['gligen'][1]]
+        return models
+
     for p in positive:
         t = p[0]
         if t.shape[0] < noise.shape[0]:
             t = torch.cat([t] * noise.shape[0])
         t = t.to(device)
-        if 'control' in p[1]:
-            control_nets += [p[1]['control']]
         positive_copy += [[t] + p[1:]]
     for n in negative:
         t = n[0]
         if t.shape[0] < noise.shape[0]:
             t = torch.cat([t] * noise.shape[0])
         t = t.to(device)
-        if 'control' in n[1]:
-            control_nets += [n[1]['control']]
         negative_copy += [[t] + n[1:]]

-    control_net_models = []
-    for x in control_nets:
-        control_net_models += x.get_control_models()
-    comfy.model_management.load_controlnet_gpu(control_net_models)
+    models = get_models(positive) + get_models(negative)
+    comfy.model_management.load_controlnet_gpu(models)

     if sampler_name in comfy.samplers.KSampler.SAMPLERS:
         sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
@@ -761,8 +809,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
         samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask)
     samples = samples.cpu()
-    for c in control_nets:
-        c.cleanup()
+    for m in models:
+        m.cleanup()
     out = latent.copy()
     out["samples"] = samples
@@ -1128,6 +1176,9 @@ NODE_CLASS_MAPPINGS = {
     "VAEEncodeTiled": VAEEncodeTiled,
     "TomePatchModel": TomePatchModel,
     "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
+    "GLIGENLoader": GLIGENLoader,
+    "GLIGENTextBoxApply": GLIGENTextBoxApply,
     "CheckpointLoader": CheckpointLoader,
     "DiffusersLoader": DiffusersLoader,
 }
...