model_base.py
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
import comfy.model_management
import comfy.conds
from enum import Enum
from . import utils

class ModelType(Enum):
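    # Parametrization the UNet was trained with. EPS predicts the added noise,
    # V_PREDICTION predicts "v" (velocity, e.g. the SD 2.x 768 models), and
    # V_PREDICTION_EDM is v-prediction paired with the continuous EDM sigma
    # schedule (see model_sampling() below and SVD_img2vid at the end of the file).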
    EPS = 1
    V_PREDICTION = 2
    V_PREDICTION_EDM = 3


from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM


def model_sampling(model_config, model_type):
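    # Build the sampling object for this model by combining a noise schedule
    # class (ModelSamplingDiscrete or ModelSamplingContinuousEDM) with a
    # prediction-type mixin (EPS or V_PREDICTION) via multiple inheritance.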
    s = ModelSamplingDiscrete

    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION
    elif model_type == ModelType.V_PREDICTION_EDM:
        c = V_PREDICTION
        s = ModelSamplingContinuousEDM

    class ModelSampling(s, c):
        pass

    return ModelSampling(model_config)


class BaseModel(torch.nn.Module):
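    # Wraps a diffusion UNet together with its latent format, sampling
    # parametrization and conditioning logic. Subclasses mostly override
    # encode_adm() / extra_conds() to build model specific conditioning.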
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__()

        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config

        if not unet_config.get("disable_unet_model_creation", False):
            self.diffusion_model = UNetModel(**unet_config, device=device)
        self.model_type = model_type
        self.model_sampling = model_sampling(model_config, model_type)

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0
        self.inpaint_model = False
        print("model_type", model_type.name)
        print("adm", self.adm_channels)

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
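        # t is the sigma for the current step: the input is scaled for the
        # model, sigma is converted to the timestep the UNet expects, and the
        # raw model output is converted back into a denoised prediction.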
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc] + [c_concat], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()
        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).float()
        context = context.to(dtype)
        extra_conds = {}
        for o in kwargs:
            extra = kwargs[o]
            if hasattr(extra, "to"):
                extra = extra.to(dtype)
            extra_conds[o] = extra
        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)

    def get_dtype(self):
        return self.diffusion_model.dtype

    def is_adm(self):
        return self.adm_channels > 0

    def encode_adm(self, **kwargs):
        return None

    def extra_conds(self, **kwargs):
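        # Builds the conditioning dict consumed during sampling: 'c_concat'
        # (mask + masked latent) for inpaint models, and 'y' (the ADM
        # embedding) whenever encode_adm() returns something.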
        out = {}
        if self.inpaint_model:
            concat_keys = ("mask", "masked_image")
            cond_concat = []
            denoise_mask = kwargs.get("denoise_mask", None)
            latent_image = kwargs.get("latent_image", None)
            noise = kwargs.get("noise", None)
            device = kwargs["device"]

            def blank_inpaint_image_like(latent_image):
                blank_image = torch.ones_like(latent_image)
                # these are the values for "zero" in pixel space translated to latent space
                blank_image[:,0] *= 0.8223
                blank_image[:,1] *= -0.6876
                blank_image[:,2] *= 0.6364
                blank_image[:,3] *= 0.1380
                return blank_image

            for ck in concat_keys:
                if denoise_mask is not None:
                    if ck == "mask":
                        cond_concat.append(denoise_mask[:,:1].to(device))
                    elif ck == "masked_image":
                        cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
                else:
                    if ck == "mask":
                        cond_concat.append(torch.ones_like(noise)[:,:1])
                    elif ck == "masked_image":
                        cond_concat.append(blank_inpaint_image_like(noise))
            data = torch.cat(cond_concat, dim=1)
            out['c_concat'] = comfy.conds.CONDNoiseShape(data)
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)
        return out

    def load_model_weights(self, sd, unet_prefix=""):
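        # Pull every key starting with unet_prefix out of sd, let the model
        # config remap it, then load it non-strictly and report any missing or
        # unexpected keys.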
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
            if k.startswith(unet_prefix):
                to_load[k[len(unet_prefix):]] = sd.pop(k)

        to_load = self.model_config.process_unet_state_dict(to_load)
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            print("unet missing:", m)

        if len(u) > 0:
            print("unet unexpected:", u)
        del to_load
        return self

    def process_latent_in(self, latent):
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent):
        return self.latent_format.process_out(latent)

    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
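        # Merge the UNet, VAE and CLIP weights back into a single checkpoint
        # style state dict, converting CLIP/VAE to fp16 when the UNet is fp16
        # and tagging v-prediction models with an empty "v_pred" key.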
        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
        unet_sd = self.diffusion_model.state_dict()
        unet_state_dict = {}
        for k in unet_sd:
            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)

        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
        if self.get_dtype() == torch.float16:
            clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
            vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)

        if self.model_type == ModelType.V_PREDICTION:
            unet_state_dict["v_pred"] = torch.tensor([])

        return {**unet_state_dict, **vae_state_dict, **clip_state_dict}

    def set_inpaint(self):
        self.inpaint_model = True

    def memory_required(self, input_shape):
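        # Rough estimate (in bytes) of the memory needed to run the model at
        # this input shape, based on the spatial area of the latent and which
        # attention implementation is active.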
        area = input_shape[0] * input_shape[2] * input_shape[3]
        if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
            #TODO: this needs to be tweaked
            return (area / (comfy.model_management.dtype_size(self.get_dtype()) * 10)) * (1024 * 1024)
        else:
            #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)


def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
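    # Builds the ADM conditioning for unCLIP models: each CLIP vision embedding
    # is noise-augmented at the requested level, concatenated with its noise
    # level embedding and weighted; multiple image conditionings are summed and
    # then re-augmented using noise_augment_merge.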
    adm_inputs = []
    weights = []
    noise_aug = []
    for unclip_cond in unclip_conditioning:
        for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
            weight = unclip_cond["strength"]
            noise_augment = unclip_cond["noise_augmentation"]
            noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
            c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
            adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
            weights.append(weight)
            noise_aug.append(noise_augment)
            adm_inputs.append(adm_out)

    if len(noise_aug) > 1:
        adm_out = torch.stack(adm_inputs).sum(0)
        noise_augment = noise_augment_merge
        noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
        c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
        adm_out = torch.cat((c_adm, noise_level_emb), 1)

    return adm_out

class SD21UNCLIP(BaseModel):
    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

    def encode_adm(self, **kwargs):
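        # With no unclip conditioning this falls back to an all-zero ADM
        # vector, otherwise the noise-augmented CLIP image embeddings are used.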
        unclip_conditioning = kwargs.get("unclip_conditioning", None)
        device = kwargs["device"]
        if unclip_conditioning is None:
            return torch.zeros((1, self.adm_channels))
        else:
            return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))

def sdxl_pooled(args, noise_augmentor):
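    # Pooled conditioning for the SDXL models: normally the pooled text encoder
    # output, but unclip (image) conditioning takes priority when present,
    # trimmed to the first 1280 channels.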
    if "unclip_conditioning" in args:
        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
    else:
        return args["pooled_output"]

class SDXLRefiner(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
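        # Refiner ADM layout: pooled CLIP embedding followed by sinusoidal
        # (Timestep) embeddings of the image size, crop coordinates and the
        # aesthetic score (6 for positive prompts, 2.5 for negative ones by default).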
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)

        if kwargs.get("prompt_type", "") == "negative":
            aesthetic_score = kwargs.get("aesthetic_score", 2.5)
        else:
            aesthetic_score = kwargs.get("aesthetic_score", 6)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([aesthetic_score])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

class SDXL(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
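        # SDXL micro-conditioning: pooled CLIP embedding followed by sinusoidal
        # embeddings of the original size, crop coordinates and target size.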
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)
        target_width = kwargs.get("target_width", width)
        target_height = kwargs.get("target_height", height)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([target_height])))
        out.append(self.embedder(torch.Tensor([target_width])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

class SVD_img2vid(BaseModel):
    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)

    def encode_adm(self, **kwargs):
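        # SVD conditioning: sinusoidal embeddings of the fps (stored as fps - 1),
        # the motion bucket id and the augmentation level.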
        fps_id = kwargs.get("fps", 6) - 1
        motion_bucket_id = kwargs.get("motion_bucket_id", 127)
        augmentation = kwargs.get("augmentation_level", 0)

        out = []
        out.append(self.embedder(torch.Tensor([fps_id])))
        out.append(self.embedder(torch.Tensor([motion_bucket_id])))
        out.append(self.embedder(torch.Tensor([augmentation])))

        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
        return flat

    def extra_conds(self, **kwargs):
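        # Besides the ADM vector this adds the concat latent image (resized and
        # repeated to match the noise batch), optional time conditioning, the
        # image_only_indicator flag and the number of video frames.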
        out = {}
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)

        latent_image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        if latent_image is None:
            latent_image = torch.zeros_like(noise)

        if latent_image.shape[1:] != noise.shape[1:]:
            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

        latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0])

        out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)

        if "time_conditioning" in kwargs:
            out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])

        out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
        out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
        return out