import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
import comfy.model_management
import comfy.conds
import comfy.ops
from enum import Enum
from . import utils

class ModelType(Enum):
    EPS = 1
    V_PREDICTION = 2
    V_PREDICTION_EDM = 3
    STABLE_CASCADE = 4


from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling


def model_sampling(model_config, model_type):
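    # Compose a ModelSampling class from a sampling schedule (s) and a prediction
    # type (c); e.g. ModelType.V_PREDICTION_EDM pairs ModelSamplingContinuousEDM
    # with V_PREDICTION.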
    s = ModelSamplingDiscrete

    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION
    elif model_type == ModelType.V_PREDICTION_EDM:
        c = V_PREDICTION
        s = ModelSamplingContinuousEDM
    elif model_type == ModelType.STABLE_CASCADE:
        c = EPS
        s = StableCascadeSampling

    class ModelSampling(s, c):
        pass

    return ModelSampling(model_config)


class BaseModel(torch.nn.Module):
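    # Common wrapper around a diffusion model: owns the model_sampling object,
    # the latent format, dtype handling and the conditioning plumbing shared by
    # all model types.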
    def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
        super().__init__()

        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config
        self.manual_cast_dtype = model_config.manual_cast_dtype

        if not unet_config.get("disable_unet_model_creation", False):
            if self.manual_cast_dtype is not None:
                operations = comfy.ops.manual_cast
            else:
                operations = comfy.ops.disable_weight_init
            self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
        self.model_type = model_type
        self.model_sampling = model_sampling(model_config, model_type)

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0
        self.inpaint_model = False
        print("model_type", model_type.name)
        print("adm", self.adm_channels)

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
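        # t is a sigma value: scale the input for the denoiser, run the diffusion
        # model with the prepared conditioning, then convert the raw output back
        # into a denoised prediction via model_sampling.calculate_denoised.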
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc] + [c_concat], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()

        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).float()
        context = context.to(dtype)
        extra_conds = {}
        for o in kwargs:
            extra = kwargs[o]
            if hasattr(extra, "dtype"):
                if extra.dtype != torch.int and extra.dtype != torch.long:
                    extra = extra.to(dtype)
            extra_conds[o] = extra

        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)

    def get_dtype(self):
        return self.diffusion_model.dtype

    def is_adm(self):
        return self.adm_channels > 0

    def encode_adm(self, **kwargs):
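        # No ADM conditioning by default; subclasses override this.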
        return None

    def extra_conds(self, **kwargs):
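        # Build the conditioning dict for the diffusion model: c_concat for
        # inpaint models (mask + masked latent), 'y' from encode_adm and the
        # cross attention conditionings.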
        out = {}
        if self.inpaint_model:
            concat_keys = ("mask", "masked_image")
            cond_concat = []
            denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
            concat_latent_image = kwargs.get("concat_latent_image", None)
            if concat_latent_image is None:
                concat_latent_image = kwargs.get("latent_image", None)
            else:
                concat_latent_image = self.process_latent_in(concat_latent_image)

            noise = kwargs.get("noise", None)
            device = kwargs["device"]

            if concat_latent_image.shape[1:] != noise.shape[1:]:
                concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

            concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])

            if len(denoise_mask.shape) == len(noise.shape):
                denoise_mask = denoise_mask[:,:1]

            denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
            if denoise_mask.shape[-2:] != noise.shape[-2:]:
                denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
            denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])

            def blank_inpaint_image_like(latent_image):
                blank_image = torch.ones_like(latent_image)
                # these are the values for "zero" in pixel space translated to latent space
                blank_image[:,0] *= 0.8223
                blank_image[:,1] *= -0.6876
                blank_image[:,2] *= 0.6364
                blank_image[:,3] *= 0.1380
                return blank_image

            for ck in concat_keys:
                if denoise_mask is not None:
                    if ck == "mask":
                        cond_concat.append(denoise_mask.to(device))
                    elif ck == "masked_image":
                        cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
                else:
                    if ck == "mask":
                        cond_concat.append(torch.ones_like(noise)[:,:1])
                    elif ck == "masked_image":
                        cond_concat.append(blank_inpaint_image_like(noise))
            data = torch.cat(cond_concat, dim=1)
            out['c_concat'] = comfy.conds.CONDNoiseShape(data)

        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)

        cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
        if cross_attn_cnet is not None:
            out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet)

        return out

    def load_model_weights(self, sd, unet_prefix=""):
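        # Strip unet_prefix from the state dict keys and load the result into the
        # diffusion model non-strictly, reporting missing/unexpected keys.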
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
            if k.startswith(unet_prefix):
                to_load[k[len(unet_prefix):]] = sd.pop(k)

        to_load = self.model_config.process_unet_state_dict(to_load)
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            print("unet missing:", m)

        if len(u) > 0:
            print("unet unexpected:", u)
        del to_load
        return self

    def process_latent_in(self, latent):
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent):
        return self.latent_format.process_out(latent)

    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
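        # Merge the processed CLIP/VAE/CLIP vision state dicts into the processed
        # unet state dict so a full checkpoint can be saved.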
        extra_sds = []
        if clip_state_dict is not None:
            extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
        if vae_state_dict is not None:
            extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
        if clip_vision_state_dict is not None:
            extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

        unet_state_dict = self.diffusion_model.state_dict()
        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

        if self.get_dtype() == torch.float16:
            extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)

        if self.model_type == ModelType.V_PREDICTION:
            unet_state_dict["v_pred"] = torch.tensor([])

        for sd in extra_sds:
            unet_state_dict.update(sd)

        return unet_state_dict

    def set_inpaint(self):
        self.inpaint_model = True

    def memory_required(self, input_shape):
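        # Rough estimate of the memory needed to sample with this input shape;
        # used by model management when deciding what can stay loaded.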
        if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
            dtype = self.get_dtype()
            if self.manual_cast_dtype is not None:
                dtype = self.manual_cast_dtype
            #TODO: this needs to be tweaked
            area = input_shape[0] * input_shape[2] * input_shape[3]
            return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
        else:
            #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
            area = input_shape[0] * input_shape[2] * input_shape[3]
            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)


def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
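    # Turn unCLIP image conditioning (CLIP vision embeds) into an ADM vector:
    # noise-augment each embed, weight it, and merge multiple conditions with a
    # weighted sum followed by one more augmentation pass.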
    adm_inputs = []
    weights = []
    noise_aug = []
    for unclip_cond in unclip_conditioning:
        for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
            weight = unclip_cond["strength"]
            noise_augment = unclip_cond["noise_augmentation"]
            noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
            c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device), seed=seed)
            adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
            weights.append(weight)
            noise_aug.append(noise_augment)
            adm_inputs.append(adm_out)

    if len(noise_aug) > 1:
        adm_out = torch.stack(adm_inputs).sum(0)
        noise_augment = noise_augment_merge
        noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
        c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
        adm_out = torch.cat((c_adm, noise_level_emb), 1)

    return adm_out

class SD21UNCLIP(BaseModel):
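    # SD 2.1 unCLIP model: ADM conditioning comes from noise-augmented CLIP
    # vision embeddings (see unclip_adm above).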
    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

    def encode_adm(self, **kwargs):
        unclip_conditioning = kwargs.get("unclip_conditioning", None)
        device = kwargs["device"]
        if unclip_conditioning is None:
            return torch.zeros((1, self.adm_channels))
        else:
            return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)

def sdxl_pooled(args, noise_augmentor):
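    # Pooled conditioning for SDXL models: the first 1280 channels of the unCLIP
    # ADM when unclip conditioning is present, otherwise the pooled CLIP text output.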
    if "unclip_conditioning" in args:
        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280]
    else:
        return args["pooled_output"]

class SDXLRefiner(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
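        # SDXL refiner micro-conditioning: original size, crop coordinates and an
        # aesthetic score, each run through the sinusoidal Timestep embedder and
        # concatenated with the pooled CLIP embedding.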
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)

        if kwargs.get("prompt_type", "") == "negative":
            aesthetic_score = kwargs.get("aesthetic_score", 2.5)
        else:
            aesthetic_score = kwargs.get("aesthetic_score", 6)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([aesthetic_score])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

class SDXL(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
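        # SDXL micro-conditioning: original size, crop coordinates and target size,
        # embedded the same way as in the refiner and concatenated with the pooled
        # CLIP embedding.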
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)
        target_width = kwargs.get("target_width", width)
        target_height = kwargs.get("target_height", height)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([target_height])))
        out.append(self.embedder(torch.Tensor([target_width])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

class SVD_img2vid(BaseModel):
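    # Stable Video Diffusion img2vid model (v-prediction with continuous EDM sampling).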
    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)

    def encode_adm(self, **kwargs):
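        # SVD conditioning vector: fps (stored as fps - 1), motion bucket id and
        # augmentation level, each embedded with Timestep(256).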
        fps_id = kwargs.get("fps", 6) - 1
        motion_bucket_id = kwargs.get("motion_bucket_id", 127)
        augmentation = kwargs.get("augmentation_level", 0)

        out = []
        out.append(self.embedder(torch.Tensor([fps_id])))
        out.append(self.embedder(torch.Tensor([motion_bucket_id])))
        out.append(self.embedder(torch.Tensor([augmentation])))

        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
        return flat

    def extra_conds(self, **kwargs):
        out = {}
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)

        latent_image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        if latent_image is None:
            latent_image = torch.zeros_like(noise)

        if latent_image.shape[1:] != noise.shape[1:]:
            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

        latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])

        out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)

        if "time_conditioning" in kwargs:
            out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])

        out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
        out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
        return out

class Stable_Zero123(BaseModel):
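    # Stable Zero123 model. cc_projection projects the incoming cross attention
    # conditioning (image embedding plus camera info in the original zero123
    # setup) down to the 768 channels expected by the UNet.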
    def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
        super().__init__(model_config, model_type, device=device)
        self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
        self.cc_projection.weight.copy_(cc_projection_weight)
        self.cc_projection.bias.copy_(cc_projection_bias)

    def extra_conds(self, **kwargs):
        out = {}

        latent_image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)

        if latent_image is None:
            latent_image = torch.zeros_like(noise)

        if latent_image.shape[1:] != noise.shape[1:]:
            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

        latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])

        out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            if cross_attn.shape[-1] != 768:
                cross_attn = self.cc_projection(cross_attn)
            out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
        return out

class SD_X4Upscaler(BaseModel):
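    # Stable Diffusion x4 upscaler: the (optionally noise-augmented) low-resolution
    # image is concatenated to the latent and the noise level is passed as 'y'.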
    def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)

    def extra_conds(self, **kwargs):
        out = {}

        image = kwargs.get("concat_image", None)
        noise = kwargs.get("noise", None)
        noise_augment = kwargs.get("noise_augmentation", 0.0)
        device = kwargs["device"]
        seed = kwargs["seed"] - 10

        noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)

        if image is None:
            image = torch.zeros_like(noise)[:,:3]

        if image.shape[1:] != noise.shape[1:]:
            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")

        noise_level = torch.tensor([noise_level], device=device)
        if noise_augment > 0:
            image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)

        image = utils.resize_to_batch_size(image, noise.shape[0])

        out['c_concat'] = comfy.conds.CONDNoiseShape(image)
        out['y'] = comfy.conds.CONDRegular(noise_level)
        return out

class StableCascade_C(BaseModel):
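    # Stable Cascade stage C (prior) model.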
    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=StageC)
        self.diffusion_model.eval().requires_grad_(False)

    def extra_conds(self, **kwargs):
        out = {}
        clip_text_pooled = kwargs["pooled_output"]
        if clip_text_pooled is not None:
            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)

        if "unclip_conditioning" in kwargs:
            embeds = []
            for unclip_cond in kwargs["unclip_conditioning"]:
                weight = unclip_cond["strength"]
                embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight)
            clip_img = torch.cat(embeds, dim=1)
        else:
            clip_img = torch.zeros((1, 1, 768))
        out["clip_img"] = comfy.conds.CONDRegular(clip_img)
        out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
        out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,)))

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
        return out


class StableCascade_B(BaseModel):
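    # Stable Cascade stage B model, conditioned on the stage C output passed in
    # as the 'effnet' input.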
    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=StageB)
        self.diffusion_model.eval().requires_grad_(False)

    def extra_conds(self, **kwargs):
        out = {}
        noise = kwargs.get("noise", None)

        clip_text_pooled = kwargs["pooled_output"]
        if clip_text_pooled is not None:
            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)

        # The size of the prior doesn't really matter when it's zeros because it gets resized, but we still want it to be batched.
        prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))

        out["effnet"] = comfy.conds.CONDRegular(prior)
        out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['clip'] = comfy.conds.CONDCrossAttn(cross_attn)
        return out