"web/git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "028e1f7ad2a50efea8391ea54b606cf865d788db"
Commit d1d2fea8 authored by comfyanonymous's avatar comfyanonymous
Browse files

Pass extra conds directly to unet.

parent 036f88c6
......@@ -50,7 +50,7 @@ class BaseModel(torch.nn.Module):
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}, **kwargs):
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
if c_concat is not None:
xc = torch.cat([x] + [c_concat], dim=1)
else:
......@@ -60,9 +60,10 @@ class BaseModel(torch.nn.Module):
xc = xc.to(dtype)
t = t.to(dtype)
context = context.to(dtype)
if c_adm is not None:
c_adm = c_adm.to(dtype)
return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float()
extra_conds = {}
for o in kwargs:
extra_conds[o] = kwargs[o].to(dtype)
return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
def get_dtype(self):
return self.diffusion_model.dtype
......@@ -107,7 +108,7 @@ class BaseModel(torch.nn.Module):
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
adm = self.encode_adm(**kwargs)
if adm is not None:
out['c_adm'] = comfy.conds.CONDRegular(adm)
out['y'] = comfy.conds.CONDRegular(adm)
return out
def load_model_weights(self, sd, unet_prefix=""):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment