model_base.py 24.6 KB
Newer Older
comfyanonymous's avatar
comfyanonymous committed
1
import torch
2
import logging
3
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
comfyanonymous's avatar
comfyanonymous committed
4
from comfy.ldm.cascade.stage_c import StageC
comfyanonymous's avatar
comfyanonymous committed
5
from comfy.ldm.cascade.stage_b import StageB
comfyanonymous's avatar
comfyanonymous committed
6
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
7
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
comfyanonymous's avatar
comfyanonymous committed
8
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
comfyanonymous's avatar
comfyanonymous committed
9
import comfy.model_management
10
import comfy.conds
11
import comfy.ops
12
from enum import Enum
13
from . import utils
comfyanonymous's avatar
comfyanonymous committed
14
import comfy.latent_formats
comfyanonymous's avatar
comfyanonymous committed
15

16
17
18
class ModelType(Enum):
    """Prediction/parameterization family of a diffusion model.

    The integer values are part of the stable interface; ``model_sampling``
    dispatches on these members to pick a sampling schedule and prediction mixin.
    """

    EPS = 1
    V_PREDICTION = 2
    V_PREDICTION_EDM = 3
    STABLE_CASCADE = 4
    EDM = 5
    FLOW = 6
23

comfyanonymous's avatar
comfyanonymous committed
24

25
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling
comfyanonymous's avatar
comfyanonymous committed
26

27

comfyanonymous's avatar
comfyanonymous committed
28
def model_sampling(model_config, model_type):
    """Build the model-sampling object for ``model_type``.

    A one-off ``ModelSampling`` class is created from a sampling schedule class
    (``s``) and a prediction-parameterization mixin (``c``), then instantiated
    with ``model_config``.

    Args:
        model_config: model configuration passed to the sampling constructor.
        model_type: a ``ModelType`` member.

    Returns:
        An instance of the combined ``ModelSampling`` class.

    Raises:
        ValueError: if ``model_type`` is not one of the handled ``ModelType`` members.
    """
    s = ModelSamplingDiscrete

    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION
    elif model_type == ModelType.V_PREDICTION_EDM:
        c = V_PREDICTION
        s = ModelSamplingContinuousEDM
    elif model_type == ModelType.FLOW:
        c = comfy.model_sampling.CONST
        s = comfy.model_sampling.ModelSamplingDiscreteFlow
    elif model_type == ModelType.STABLE_CASCADE:
        c = EPS
        s = StableCascadeSampling
    elif model_type == ModelType.EDM:
        c = EDM
        s = ModelSamplingContinuousEDM
    else:
        # Without this branch `c` would be unbound and the failure would surface
        # as a confusing UnboundLocalError at class creation below.
        raise ValueError("Unsupported model type: {}".format(model_type))

    class ModelSampling(s, c):
        pass

    return ModelSampling(model_config)


comfyanonymous's avatar
comfyanonymous committed
54
class BaseModel(torch.nn.Module):
    """Diffusion network plus its sampling object and conditioning helpers.

    Subclasses customise conditioning by overriding ``encode_adm`` and/or
    ``extra_conds``; ``apply_model`` runs the wrapped network once.
    """

    def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
        super().__init__()

        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config
        self.manual_cast_dtype = model_config.manual_cast_dtype

        if not unet_config.get("disable_unet_model_creation", False):
            # Choose the ops implementation: manual_cast when a manual cast dtype is
            # configured, otherwise disable_weight_init.
            if self.manual_cast_dtype is not None:
                operations = comfy.ops.manual_cast
            else:
                operations = comfy.ops.disable_weight_init
            self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
        self.model_type = model_type
        self.model_sampling = model_sampling(model_config, model_type)

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0

        self.concat_keys = ()
        logging.info("model_type {}".format(model_type.name))
        logging.debug("adm {}".format(self.adm_channels))

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options=None, **kwargs):
        """Run the diffusion network on ``x`` at ``t`` and return the denoised result.

        Args:
            x: noisy input.
            t: sigma value(s); converted to model timesteps via ``model_sampling``.
            c_concat: optional tensor concatenated to the scaled input along dim 1.
            c_crossattn: optional cross-attention context; cast to the compute dtype.
            control: passed through to the diffusion model unchanged.
            transformer_options: dict forwarded to the diffusion model. Defaults to
                a fresh empty dict per call.
            **kwargs: extra conditioning forwarded to the diffusion model. Values
                with a ``dtype`` other than int/long are cast to the compute dtype.

        Returns:
            ``model_sampling.calculate_denoised(sigma, model_output, x)``.
        """
        # A shared `{}` default would be handed to the diffusion model on every call,
        # so anything it stores there would leak into later calls.
        if transformer_options is None:
            transformer_options = {}

        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc] + [c_concat], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()

        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).float()
        # c_crossattn defaults to None; only cast when a context was given.
        if context is not None:
            context = context.to(dtype)

        extra_conds = {}
        for o in kwargs:
            extra = kwargs[o]
            if hasattr(extra, "dtype"):
                # int/long tensors keep their dtype; other dtype-bearing values are cast.
                if extra.dtype != torch.int and extra.dtype != torch.long:
                    extra = extra.to(dtype)
            extra_conds[o] = extra

        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)

    def get_dtype(self):
        """Return the dtype reported by the wrapped diffusion model."""
        return self.diffusion_model.dtype

    def is_adm(self):
        """Return True when the config declares ADM input channels."""
        return self.adm_channels > 0

    def encode_adm(self, **kwargs):
        """Return the ADM (``y``) conditioning tensor, or None. Subclasses override this."""
        return None

    def extra_conds(self, **kwargs):
        """Build the conditioning dict passed to ``apply_model``.

        Keys produced here:
          - ``c_concat``: inpaint concat channels built from ``self.concat_keys``
            (only when that tuple is non-empty), overridden by ``noise_concat`` if given.
          - ``y``: result of ``encode_adm`` when not None.
          - ``c_crossattn`` / ``crossattn_controlnet``: from ``cross_attn`` /
            ``cross_attn_controlnet``.
        ``kwargs["device"]`` is required when ``concat_keys`` is non-empty.
        """
        out = {}
        if len(self.concat_keys) > 0:
            cond_concat = []
            denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
            concat_latent_image = kwargs.get("concat_latent_image", None)
            if concat_latent_image is None:
                concat_latent_image = kwargs.get("latent_image", None)
            else:
                concat_latent_image = self.process_latent_in(concat_latent_image)

            noise = kwargs.get("noise", None)
            device = kwargs["device"]

            if concat_latent_image.shape[1:] != noise.shape[1:]:
                concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

            concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])

            if denoise_mask is not None:
                # Keep a single mask channel, then match the noise's spatial size and batch.
                if len(denoise_mask.shape) == len(noise.shape):
                    denoise_mask = denoise_mask[:,:1]

                denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
                if denoise_mask.shape[-2:] != noise.shape[-2:]:
                    denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
                denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])

            for ck in self.concat_keys:
                if denoise_mask is not None:
                    if ck == "mask":
                        cond_concat.append(denoise_mask.to(device))
                    elif ck == "masked_image":
                        cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
                else:
                    # No mask: use an all-ones mask and a "blank" image.
                    if ck == "mask":
                        cond_concat.append(torch.ones_like(noise)[:,:1])
                    elif ck == "masked_image":
                        cond_concat.append(self.blank_inpaint_image_like(noise))
            data = torch.cat(cond_concat, dim=1)
            out['c_concat'] = comfy.conds.CONDNoiseShape(data)

        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)

        cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
        if cross_attn_cnet is not None:
            out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet)

        c_concat = kwargs.get("noise_concat", None)
        if c_concat is not None:
            out['c_concat'] = comfy.conds.CONDNoiseShape(c_concat)

        return out

    def load_model_weights(self, sd, unet_prefix=""):
        """Load diffusion-model weights from ``sd`` and return ``self``.

        Keys starting with ``unet_prefix`` are popped from ``sd`` (the caller's
        dict is mutated), stripped of the prefix, passed through
        ``model_config.process_unet_state_dict`` and loaded non-strictly. Missing
        and unexpected keys are logged as warnings.
        """
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
            if k.startswith(unet_prefix):
                to_load[k[len(unet_prefix):]] = sd.pop(k)

        to_load = self.model_config.process_unet_state_dict(to_load)
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            logging.warning("unet missing: {}".format(m))

        if len(u) > 0:
            logging.warning("unet unexpected: {}".format(u))
        del to_load
        return self

    def process_latent_in(self, latent):
        """Apply ``latent_format.process_in`` to ``latent``."""
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent):
        """Apply ``latent_format.process_out`` to ``latent``."""
        return self.latent_format.process_out(latent)

    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
        """Return a single state dict combining the UNet and any given component dicts.

        Each component dict is run through the matching ``model_config`` save
        hook. V_PREDICTION models also get an empty ``"v_pred"`` marker tensor.
        """
        extra_sds = []
        if clip_state_dict is not None:
            extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
        if vae_state_dict is not None:
            extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
        if clip_vision_state_dict is not None:
            extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

        unet_state_dict = self.diffusion_model.state_dict()
        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

        if self.model_type == ModelType.V_PREDICTION:
            unet_state_dict["v_pred"] = torch.tensor([])

        for sd in extra_sds:
            unet_state_dict.update(sd)

        return unet_state_dict

    def set_inpaint(self):
        """Enable inpaint conditioning: concat a mask and a masked-image latent."""
        self.concat_keys = ("mask", "masked_image")
        def blank_inpaint_image_like(latent_image):
            blank_image = torch.ones_like(latent_image)
            # these are the values for "zero" in pixel space translated to latent space
            blank_image[:,0] *= 0.8223
            blank_image[:,1] *= -0.6876
            blank_image[:,2] *= 0.6364
            blank_image[:,3] *= 0.1380
            return blank_image
        self.blank_inpaint_image_like = blank_inpaint_image_like

    def memory_required(self, input_shape):
        """Heuristic memory estimate for an input of ``input_shape``.

        Uses dims 0, 2 and 3 of ``input_shape``. The result is scaled by
        1024 * 1024 (presumably bytes — confirm with callers).
        """
        if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
            dtype = self.get_dtype()
            if self.manual_cast_dtype is not None:
                dtype = self.manual_cast_dtype
            #TODO: this needs to be tweaked
            area = input_shape[0] * input_shape[2] * input_shape[3]
            return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
        else:
            #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
            area = input_shape[0] * input_shape[2] * input_shape[3]
            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)


244
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
    """Build the ADM (``y``) tensor for unCLIP-style image conditioning.

    Each row of every entry's ``clip_vision_output.image_embeds`` is passed to
    ``noise_augmentor`` at that entry's ``noise_augmentation`` level. The returned
    embedding is concatenated (dim 1) with the returned noise-level embedding and
    scaled by the entry's ``strength``. When more than one embedding is present,
    the results are summed and augmented once more at ``noise_augment_merge``.

    Args:
        unclip_conditioning: iterable of dicts with ``clip_vision_output``,
            ``strength`` and ``noise_augmentation`` keys.
        device: device for the embeddings and noise-level tensors.
        noise_augmentor: callable ``(x, noise_level=..., seed=...) -> (x, level_emb)``
            exposing ``max_noise_level`` and ``time_embed.dim``.
        noise_augment_merge: augmentation level for the merged result.
        seed: forwarded to the per-embedding augmentation calls.

    Returns:
        The ADM tensor.

    Raises:
        ValueError: if ``unclip_conditioning`` yields no image embeddings.
    """
    adm_inputs = []
    for unclip_cond in unclip_conditioning:
        for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
            weight = unclip_cond["strength"]
            noise_augment = unclip_cond["noise_augmentation"]
            noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
            c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device), seed=seed)
            adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
            adm_inputs.append(adm_out)

    if len(adm_inputs) == 0:
        # Previously this fell through to `return adm_out` and raised UnboundLocalError.
        raise ValueError("unclip_adm: unclip_conditioning contains no image embeddings")

    if len(adm_inputs) > 1:
        adm_out = torch.stack(adm_inputs).sum(0)
        noise_augment = noise_augment_merge
        noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
        # NOTE(review): `seed` is not forwarded to this merge pass — confirm whether intended.
        c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
        adm_out = torch.cat((c_adm, noise_level_emb), 1)

    return adm_out
267

comfyanonymous's avatar
comfyanonymous committed
268
class SD21UNCLIP(BaseModel):
    """SD 2.1 unCLIP: CLIP-vision image embeddings are fed in through the ADM input."""

    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

    def encode_adm(self, **kwargs):
        """Return the unCLIP ADM tensor, or zeros of shape (1, adm_channels) without conditioning.

        ``kwargs["device"]`` is always required. The merge augmentation level
        defaults to 0.05, and the seed used is ``seed - 10`` (seed defaults to 0).
        """
        conditioning = kwargs.get("unclip_conditioning", None)
        target_device = kwargs["device"]
        if conditioning is not None:
            merge_level = kwargs.get("unclip_noise_augment_merge", 0.05)
            shifted_seed = kwargs.get("seed", 0) - 10
            return unclip_adm(conditioning, target_device, self.noise_augmentor, merge_level, shifted_seed)
        return torch.zeros((1, self.adm_channels))
280

281
282
def sdxl_pooled(args, noise_augmentor):
    """Return the pooled conditioning vector for SDXL-family models.

    When ``args`` contains ``"unclip_conditioning"``, the first 1280 columns of
    ``unclip_adm`` (seeded with ``seed - 10``) are returned. Otherwise
    ``args["pooled_output"]`` is returned as-is.
    """
    if "unclip_conditioning" not in args:
        return args["pooled_output"]

    conditioning = args.get("unclip_conditioning", None)
    target_device = args["device"]
    shifted_seed = args.get("seed", 0) - 10
    adm = unclip_adm(conditioning, target_device, noise_augmentor, seed=shifted_seed)
    return adm[:, :1280]

287
class SDXLRefiner(BaseModel):
    """SDXL refiner: ADM = pooled CLIP output + embedded size/crop/aesthetic-score values."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(
            noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"},
            timestep_dim=1280,
        )

    def encode_adm(self, **kwargs):
        """Concatenate the pooled output with Timestep(256) embeddings of
        height, width, crop_h, crop_w and aesthetic score (in that order).

        The embeddings are repeated across the pooled output's batch dimension.
        Aesthetic score defaults to 2.5 for negative prompts and to 6 otherwise.
        """
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)

        is_negative = kwargs.get("prompt_type", "") == "negative"
        aesthetic_score = kwargs.get("aesthetic_score", 2.5 if is_negative else 6)

        scalars = (
            kwargs.get("height", 768),
            kwargs.get("width", 768),
            kwargs.get("crop_h", 0),
            kwargs.get("crop_w", 0),
            aesthetic_score,
        )
        embedded = torch.cat([self.embedder(torch.Tensor([value])) for value in scalars])
        flat = torch.flatten(embedded).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

class SDXL(BaseModel):
    """SDXL base: ADM = pooled CLIP output + embedded size/crop/target-size values."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(
            noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"},
            timestep_dim=1280,
        )

    def encode_adm(self, **kwargs):
        """Concatenate the pooled output with Timestep(256) embeddings of height,
        width, crop_h, crop_w, target_height and target_width (in that order).

        Target sizes default to the (defaulted) height/width. The embeddings are
        repeated across the pooled output's batch dimension.
        """
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)

        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        scalars = (
            height,
            width,
            kwargs.get("crop_h", 0),
            kwargs.get("crop_w", 0),
            kwargs.get("target_height", height),
            kwargs.get("target_width", width),
        )
        embedded = torch.cat([self.embedder(torch.Tensor([value])) for value in scalars])
        flat = torch.flatten(embedded).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
comfyanonymous's avatar
comfyanonymous committed
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372

class SVD_img2vid(BaseModel):
    """Stable Video Diffusion image-to-video model."""

    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)

    def encode_adm(self, **kwargs):
        """Embed (fps - 1, motion_bucket_id, augmentation_level) into a (1, N) ADM tensor.

        Defaults: fps=6, motion_bucket_id=127, augmentation_level=0.
        """
        values = (
            kwargs.get("fps", 6) - 1,
            kwargs.get("motion_bucket_id", 127),
            kwargs.get("augmentation_level", 0),
        )
        embeds = [self.embedder(torch.Tensor([v])) for v in values]
        return torch.flatten(torch.cat(embeds)).unsqueeze(dim=0)

    def extra_conds(self, **kwargs):
        """Build ``y``, ``c_concat``, optional ``c_crossattn``/``time_context`` and ``num_video_frames``.

        The concat latent defaults to zeros shaped like ``noise`` and is resized
        to the noise's spatial size and batch. ``num_video_frames`` is ``noise.shape[0]``.
        """
        out = {}
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)

        latent_image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]  # not used below, but a missing "device" still raises KeyError

        if latent_image is None:
            latent_image = torch.zeros_like(noise)
        if latent_image.shape[1:] != noise.shape[1:]:
            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
        out['c_concat'] = comfy.conds.CONDNoiseShape(utils.resize_to_batch_size(latent_image, noise.shape[0]))

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)

        if "time_conditioning" in kwargs:
            out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])

        out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
        return out
386

comfyanonymous's avatar
comfyanonymous committed
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
class SV3D_u(SVD_img2vid):
    """SV3D_u: the ADM vector embeds only the augmentation level."""

    def encode_adm(self, **kwargs):
        """Return a (1, N) embedding of ``augmentation_level`` (default 0)."""
        level = torch.flatten(torch.Tensor([kwargs.get("augmentation_level", 0)]))
        embedded = self.embedder(level)
        return torch.flatten(torch.cat([embedded])).unsqueeze(dim=0)

class SV3D_p(SVD_img2vid):
    """SV3D_p: the ADM vector embeds augmentation level plus camera elevation/azimuth."""

    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder_512 = Timestep(512)

    def encode_adm(self, **kwargs):
        """Return augmentation, (90 - elevation) and azimuth embeddings concatenated on dim 1.

        Elevation and azimuth are given in degrees. Each is reduced with
        ``fmod(..., 360)`` and converted to radians before embedding. Every
        embedding is resized to ``noise.shape[0]`` rows.
        """
        augmentation = kwargs.get("augmentation_level", 0)
        elevation = kwargs.get("elevation", 0)  # degrees
        azimuth = kwargs.get("azimuth", 0)  # degrees
        noise = kwargs.get("noise", None)

        def to_radians(degrees):
            return torch.deg2rad(torch.fmod(torch.flatten(degrees), 360.0))

        parts = [
            self.embedder(torch.flatten(torch.Tensor([augmentation]))),
            self.embedder_512(to_radians(90 - torch.Tensor([elevation]))),
            self.embedder_512(to_radians(torch.Tensor([azimuth]))),
        ]
        batch = noise.shape[0]
        return torch.cat([utils.resize_to_batch_size(p, batch) for p in parts], dim=1)


417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
class Stable_Zero123(BaseModel):
    """Stable Zero123 with a linear projection applied to some cross-attention inputs.

    ``cc_projection_weight`` / ``cc_projection_bias`` initialise ``cc_projection``.
    Cross-attention inputs whose last dimension is not 768 are passed through it.
    """

    def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
        super().__init__(model_config, model_type, device=device)
        self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
        # The new layer's parameters are leaf tensors that may require grad; an in-place
        # copy_ into them raises unless autograd is disabled, so do it under no_grad.
        with torch.no_grad():
            self.cc_projection.weight.copy_(cc_projection_weight)
            self.cc_projection.bias.copy_(cc_projection_bias)

    def extra_conds(self, **kwargs):
        """Build ``c_concat`` (zeros default, resized to ``noise``) and optional ``c_crossattn``."""
        out = {}

        latent_image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)

        if latent_image is None:
            latent_image = torch.zeros_like(noise)

        if latent_image.shape[1:] != noise.shape[1:]:
            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

        latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])

        out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            if cross_attn.shape[-1] != 768:
                cross_attn = self.cc_projection(cross_attn)
            out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
        return out
446
447
448
449

class SD_X4Upscaler(BaseModel):
    """SD x4 upscaler: concatenates a (noise-augmented) low-res image and uses its noise level as ``y``."""

    def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)

    def extra_conds(self, **kwargs):
        """Build ``c_concat`` from ``concat_image`` and ``y`` from the noise level.

        The image defaults to 3 zero channels shaped like ``noise`` and is
        upscaled to the noise's spatial size when its shape differs. When
        ``noise_augmentation`` > 0 it is passed through ``noise_augmentor``. The
        seed used is ``seed - 10``, with ``seed`` defaulting to 0 as elsewhere in
        this module.
        """
        out = {}

        image = kwargs.get("concat_image", None)
        noise = kwargs.get("noise", None)
        noise_augment = kwargs.get("noise_augmentation", 0.0)
        device = kwargs["device"]
        # Previously kwargs["seed"], which raised KeyError when no seed was supplied.
        seed = kwargs.get("seed", 0) - 10

        noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)

        if image is None:
            image = torch.zeros_like(noise)[:,:3]

        if image.shape[1:] != noise.shape[1:]:
            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")

        noise_level = torch.tensor([noise_level], device=device)
        if noise_augment > 0:
            image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)

        image = utils.resize_to_batch_size(image, noise.shape[0])

        out['c_concat'] = comfy.conds.CONDNoiseShape(image)
        out['y'] = comfy.conds.CONDRegular(noise_level)
        return out
comfyanonymous's avatar
comfyanonymous committed
478

comfyanonymous's avatar
comfyanonymous committed
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
class IP2P:
    """Mixin adding InstructPix2Pix-style image-concat conditioning.

    The host class must provide ``process_ip2p_image_in`` and ``encode_adm``.
    """

    def extra_conds(self, **kwargs):
        """Build ``c_concat`` from ``concat_latent_image`` (zeros default) and optional ``y``."""
        noise = kwargs.get("noise", None)
        image = kwargs.get("concat_latent_image", None)
        device = kwargs["device"]

        image = torch.zeros_like(noise) if image is None else image
        if image.shape[1:] != noise.shape[1:]:
            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
        image = utils.resize_to_batch_size(image, noise.shape[0])

        out = {'c_concat': comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))}
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)
        return out

class SD15_instructpix2pix(IP2P, BaseModel):
    """SD1.5 InstructPix2Pix: the concat image is used without latent-format processing."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)

        def passthrough(image):
            return image

        self.process_ip2p_image_in = passthrough

class SDXL_instructpix2pix(IP2P, SDXL):
    """SDXL InstructPix2Pix. V_PREDICTION_EDM (cosxl) variants run the concat image
    through the SDXL latent format; other variants (diffusers) use it unchanged."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)

        def cosxl_process(image):
            return comfy.latent_formats.SDXL().process_in(image)

        def diffusers_process(image):
            return image

        if model_type == ModelType.V_PREDICTION_EDM:
            self.process_ip2p_image_in = cosxl_process
        else:
            self.process_ip2p_image_in = diffusers_process
comfyanonymous's avatar
comfyanonymous committed
513
514


comfyanonymous's avatar
comfyanonymous committed
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
class StableCascade_C(BaseModel):
    """Stable Cascade stage C, built on ``StageC`` and put in eval mode with grads disabled."""

    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=StageC)
        self.diffusion_model.eval().requires_grad_(False)

    def extra_conds(self, **kwargs):
        """Build ``clip_text_pooled``, ``clip_img``, ``sca``, ``crp`` and optional ``clip_text``.

        ``kwargs["pooled_output"]`` is required (it may be None). ``clip_img``
        concatenates the strength-weighted unCLIP image embeds on dim 1, or is
        zeros of shape (1, 1, 768) when no unCLIP conditioning is given.
        """
        out = {}
        pooled = kwargs["pooled_output"]
        if pooled is not None:
            out['clip_text_pooled'] = comfy.conds.CONDRegular(pooled)

        if "unclip_conditioning" not in kwargs:
            clip_img = torch.zeros((1, 1, 768))
        else:
            weighted = []
            for cond in kwargs["unclip_conditioning"]:
                strength = cond["strength"]
                weighted.append(cond["clip_vision_output"].image_embeds.unsqueeze(0) * strength)
            clip_img = torch.cat(weighted, dim=1)
        out["clip_img"] = comfy.conds.CONDRegular(clip_img)
        out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
        out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,)))

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
        return out

comfyanonymous's avatar
comfyanonymous committed
543
544
545
546
547
548
549
550
551
552
553
554

class StableCascade_B(BaseModel):
    """Stable Cascade stage B, built on ``StageB`` and put in eval mode with grads disabled."""

    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=StageB)
        self.diffusion_model.eval().requires_grad_(False)

    def extra_conds(self, **kwargs):
        """Build ``clip`` (when pooled output is not None), ``effnet`` and ``sca``.

        ``kwargs["pooled_output"]`` is required. ``effnet`` is the
        ``stable_cascade_prior`` value when that key is present. Otherwise it is a
        zero tensor derived from ``noise``'s shape, dtype, layout and device.
        """
        out = {}
        noise = kwargs.get("noise", None)

        clip_text_pooled = kwargs["pooled_output"]
        if clip_text_pooled is not None:
            out['clip'] = comfy.conds.CONDRegular(clip_text_pooled)

        if "stable_cascade_prior" in kwargs:
            prior = kwargs["stable_cascade_prior"]
        else:
            # Built only when needed: previously this was the eager .get() default, so it
            # was allocated (and required `noise`) even when a prior was supplied.
            #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
            prior = torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device)

        out["effnet"] = comfy.conds.CONDRegular(prior)
        out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
        return out
comfyanonymous's avatar
comfyanonymous committed
563
564
565
566
567
568
569
570
571
572


class SD3(BaseModel):
    """Stable Diffusion 3 (MMDiT) model using FLOW sampling by default."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=OpenAISignatureMMDITWrapper)

    def encode_adm(self, **kwargs):
        """The ADM input is ``kwargs["pooled_output"]`` unchanged."""
        return kwargs["pooled_output"]

    def extra_conds(self, **kwargs):
        """Base conditioning, with ``c_crossattn`` re-wrapped as ``CONDRegular``."""
        conds = super().extra_conds(**kwargs)
        context = kwargs.get("cross_attn", None)
        if context is not None:
            conds['c_crossattn'] = comfy.conds.CONDRegular(context)
        return conds

    def memory_required(self, input_shape):
        """Heuristic memory estimate from dims 0, 2 and 3 of ``input_shape``.

        Scaled by 1024 * 1024 (presumably bytes — confirm with callers).
        """
        area = input_shape[0] * input_shape[2] * input_shape[3]
        mib = 1024 * 1024
        efficient_attention = comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention()
        if not efficient_attention:
            return (area * 0.3) * mib

        dtype = self.get_dtype()
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype
        #TODO: this probably needs to be tweaked
        return (area * comfy.model_management.dtype_size(dtype) * 0.012) * mib