# model_base.py
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
import comfy.conds
from enum import Enum
from . import utils

class ModelType(Enum):
    EPS = 1
    V_PREDICTION = 2
    V_PREDICTION_EDM = 3


from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM


def model_sampling(model_config, model_type):
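    """Build the sampling helper for a model.

    Combines a prediction target (EPS or V_PREDICTION) with a sigma schedule
    (ModelSamplingDiscrete, or ModelSamplingContinuousEDM for V_PREDICTION_EDM)
    into a single object exposing calculate_input(), timestep() and
    calculate_denoised(), which BaseModel.apply_model() uses below.
    """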
    s = ModelSamplingDiscrete

    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION
    elif model_type == ModelType.V_PREDICTION_EDM:
        c = V_PREDICTION
        s = ModelSamplingContinuousEDM

    class ModelSampling(s, c):
        pass

    return ModelSampling(model_config)


class BaseModel(torch.nn.Module):
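    """Wraps the diffusion UNet plus the sampling math shared by all model families.

    Subclasses (SD21UNCLIP, SDXLRefiner, SDXL, SVD_img2vid) mainly override
    encode_adm() / extra_conds() to supply their model-specific conditioning.

    Rough usage sketch (illustrative only; real calls go through comfy.samplers
    and the checkpoint loaders in comfy.sd):

        model = BaseModel(model_config)
        model.load_model_weights(state_dict, "model.diffusion_model.")
        denoised = model.apply_model(latent, sigma, c_crossattn=text_context)
    """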
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__()

        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config

        if not unet_config.get("disable_unet_model_creation", False):
            self.diffusion_model = UNetModel(**unet_config, device=device)
        self.model_type = model_type
        self.model_sampling = model_sampling(model_config, model_type)

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0
        self.inpaint_model = False
        print("model_type", model_type.name)
        print("adm", self.adm_channels)

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
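        """Run one denoising step.

        `t` is a sigma value: the input is rescaled with calculate_input(), any
        c_concat conditioning is appended on the channel dimension, the sigma is
        mapped to the network's timestep, and the raw network output is converted
        back into a denoised prediction with calculate_denoised().
        """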
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc, c_concat], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()
        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).float()
        context = context.to(dtype)
        extra_conds = {}
        for o in kwargs:
            extra = kwargs[o]
            if hasattr(extra, "to"):
                extra = extra.to(dtype)
            extra_conds[o] = extra
        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)

    def get_dtype(self):
        return self.diffusion_model.dtype

    def is_adm(self):
        return self.adm_channels > 0

    def encode_adm(self, **kwargs):
        return None

    def extra_conds(self, **kwargs):
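        """Build the dict of extra conditioning, wrapped in comfy.conds objects.

        Inpaint models get 'c_concat' (denoise mask + masked latent image, or an
        all-ones mask and a blank latent when no mask is supplied) concatenated on
        the channel dimension; models with ADM channels get 'y' from encode_adm().
        """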
        out = {}
        if self.inpaint_model:
            concat_keys = ("mask", "masked_image")
            cond_concat = []
            denoise_mask = kwargs.get("denoise_mask", None)
            latent_image = kwargs.get("latent_image", None)
            noise = kwargs.get("noise", None)
            device = kwargs["device"]

            def blank_inpaint_image_like(latent_image):
                blank_image = torch.ones_like(latent_image)
                # these are the values for "zero" in pixel space translated to latent space
                blank_image[:,0] *= 0.8223
                blank_image[:,1] *= -0.6876
                blank_image[:,2] *= 0.6364
                blank_image[:,3] *= 0.1380
                return blank_image

            for ck in concat_keys:
                if denoise_mask is not None:
                    if ck == "mask":
                        cond_concat.append(denoise_mask[:,:1].to(device))
                    elif ck == "masked_image":
                        cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
                else:
                    if ck == "mask":
                        cond_concat.append(torch.ones_like(noise)[:,:1])
                    elif ck == "masked_image":
                        cond_concat.append(blank_inpaint_image_like(noise))
            data = torch.cat(cond_concat, dim=1)
            out['c_concat'] = comfy.conds.CONDNoiseShape(data)
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)
        return out

    def load_model_weights(self, sd, unet_prefix=""):
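        """Load UNet weights from `sd`, stripping `unet_prefix` (e.g. "model.diffusion_model.").

        Keys are remapped by the model config and loaded non-strictly; missing and
        unexpected keys are printed rather than raised.
        """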
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
            if k.startswith(unet_prefix):
                to_load[k[len(unet_prefix):]] = sd.pop(k)

        to_load = self.model_config.process_unet_state_dict(to_load)
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            print("unet missing:", m)

        if len(u) > 0:
            print("unet unexpected:", u)
        del to_load
        return self

    def process_latent_in(self, latent):
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent):
        return self.latent_format.process_out(latent)

    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
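        """Assemble a single checkpoint-style state dict from UNet, VAE and CLIP weights.

        Each piece is run through the model config's *_for_saving hooks; CLIP and
        VAE weights are cast to fp16 when the UNet is fp16, and v-prediction models
        are tagged with an empty "v_pred" marker tensor.
        """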
        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
        unet_sd = self.diffusion_model.state_dict()
        unet_state_dict = {}
        for k in unet_sd:
            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)

        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
        if self.get_dtype() == torch.float16:
            clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
            vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)

        if self.model_type == ModelType.V_PREDICTION:
            unet_state_dict["v_pred"] = torch.tensor([])

        return {**unet_state_dict, **vae_state_dict, **clip_state_dict}

    def set_inpaint(self):
        self.inpaint_model = True

    def memory_required(self, input_shape):
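        """Heuristic estimate of the memory needed for an input of `input_shape`
        (the 1024 * 1024 factor puts the result in bytes).

        Both branches scale with the batch * height * width area of the latent
        input; the xformers/flash-attention path also scales with the dtype size.
        """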
        if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
            #TODO: this needs to be tweaked
            area = input_shape[0] * input_shape[2] * input_shape[3]
            return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
        else:
            #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
            area = input_shape[0] * input_shape[2] * input_shape[3]
            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)


def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
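    """Turn unCLIP image-embed conditioning into an ADM vector.

    Each CLIP vision embedding is noise-augmented at its requested level,
    concatenated with the corresponding noise-level embedding, and scaled by its
    strength.  When more than one conditioning is given, their weighted sum is
    noise-augmented once more at `noise_augment_merge`.  Expects at least one
    conditioning entry.
    """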
    adm_inputs = []
    weights = []
    noise_aug = []
    for unclip_cond in unclip_conditioning:
        for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
            weight = unclip_cond["strength"]
            noise_augment = unclip_cond["noise_augmentation"]
            noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
            c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
            adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
            weights.append(weight)
            noise_aug.append(noise_augment)
            adm_inputs.append(adm_out)

    if len(noise_aug) > 1:
        adm_out = torch.stack(adm_inputs).sum(0)
        noise_augment = noise_augment_merge
        noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
        c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
        adm_out = torch.cat((c_adm, noise_level_emb), 1)

    return adm_out

class SD21UNCLIP(BaseModel):
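    """SD 2.1 unCLIP variant: adds a CLIP-embedding noise augmentor and returns
    unclip_adm() output as ADM (or zeros when no unclip conditioning is given)."""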
    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

    def encode_adm(self, **kwargs):
        unclip_conditioning = kwargs.get("unclip_conditioning", None)
        device = kwargs["device"]
        if unclip_conditioning is None:
            return torch.zeros((1, self.adm_channels))
        else:
            return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))

def sdxl_pooled(args, noise_augmentor):
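    """Pooled conditioning for SDXL models: unCLIP image embeds (first 1280 dims)
    when present, otherwise the text encoder's pooled output."""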
    if "unclip_conditioning" in args:
        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
    else:
        return args["pooled_output"]

class SDXLRefiner(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, timestep_dim=1280)

    def encode_adm(self, **kwargs):
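        """Build the refiner ADM vector: pooled CLIP embedding followed by Timestep(256)
        embeddings of height, width, crop_h, crop_w and the aesthetic score
        (default 6 for positive prompts, 2.5 for negative)."""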
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)

        if kwargs.get("prompt_type", "") == "negative":
            aesthetic_score = kwargs.get("aesthetic_score", 2.5)
        else:
            aesthetic_score = kwargs.get("aesthetic_score", 6)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([aesthetic_score])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

class SDXL(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, timestep_dim=1280)

    def encode_adm(self, **kwargs):
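        """Build the SDXL ADM vector: pooled CLIP embedding followed by Timestep(256)
        embeddings of height, width, crop_h, crop_w, target_height and target_width."""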
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)
        target_width = kwargs.get("target_width", width)
        target_height = kwargs.get("target_height", height)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([target_height])))
        out.append(self.embedder(torch.Tensor([target_width])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

class SVD_img2vid(BaseModel):
    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)

    def encode_adm(self, **kwargs):
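        """ADM conditioning for img2vid: Timestep(256) embeddings of fps - 1,
        motion_bucket_id and the augmentation level, flattened into a single row."""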
        fps_id = kwargs.get("fps", 6) - 1
        motion_bucket_id = kwargs.get("motion_bucket_id", 127)
        augmentation = kwargs.get("augmentation_level", 0)

        out = []
        out.append(self.embedder(torch.Tensor([fps_id])))
        out.append(self.embedder(torch.Tensor([motion_bucket_id])))
        out.append(self.embedder(torch.Tensor([augmentation])))

        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
        return flat

    def extra_conds(self, **kwargs):
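        """Video-specific conditioning: 'y' from encode_adm(), 'c_concat' from the
        conditioning latent image (resized and batch-matched to the noise), an
        optional 'time_context' cross-attention cond, plus the image_only_indicator
        and num_video_frames constants the video UNet expects."""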
        out = {}
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)

        latent_image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        if latent_image is None:
            latent_image = torch.zeros_like(noise)

        if latent_image.shape[1:] != noise.shape[1:]:
            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

        latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])

        out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)

        if "time_conditioning" in kwargs:
            out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])

        out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
        out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
        return out