"web/vscode:/vscode.git/clone" did not exist on "1a74611c6e725f1ffb6629d08fbd04bb658f2704"
Commit 064d7583 authored by comfyanonymous's avatar comfyanonymous
Browse files

Add a CONDConstant for passing non tensor conds to unet.

parent 794dd206
......@@ -62,3 +62,18 @@ class CONDCrossAttn(CONDRegular):
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
out.append(c)
return torch.cat(out)
class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(self.cond)
def can_concat(self, other):
if self.cond != other.cond:
return False
return True
def concat(self, others):
return self.cond
......@@ -61,7 +61,10 @@ class BaseModel(torch.nn.Module):
context = context.to(dtype)
extra_conds = {}
for o in kwargs:
extra_conds[o] = kwargs[o].to(dtype)
extra = kwargs[o]
if hasattr(extra, "to"):
extra = extra.to(dtype)
extra_conds[o] = extra
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment