import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
import comfy.conds
from enum import Enum
import contextlib
from . import utils

class ModelType(Enum):
    EPS = 1
    V_PREDICTION = 2
    V_PREDICTION_EDM = 3


from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM


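# A sampling object combines a sigma schedule (ModelSamplingDiscrete or
# ModelSamplingContinuousEDM) with a prediction-type mixin (EPS or V_PREDICTION).
# Rough usage sketch, mirroring how BaseModel.apply_model uses it below
# ("ms" is just an illustrative name):
#
#   ms = model_sampling(model_config, ModelType.EPS)
#   xc = ms.calculate_input(sigma, x)                       # scale the noisy latent for the UNet
#   t = ms.timestep(sigma)                                  # map sigma to a UNet timestep
#   denoised = ms.calculate_denoised(sigma, model_out, x)   # turn the raw prediction into x0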
def model_sampling(model_config, model_type):
    s = ModelSamplingDiscrete

    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION
    elif model_type == ModelType.V_PREDICTION_EDM:
        c = V_PREDICTION
        s = ModelSamplingContinuousEDM

    class ModelSampling(s, c):
        pass

    return ModelSampling(model_config)


class BaseModel(torch.nn.Module):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__()

        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config

        if not unet_config.get("disable_unet_model_creation", False):
            self.diffusion_model = UNetModel(**unet_config, device=device)
        self.model_type = model_type
        self.model_sampling = model_sampling(model_config, model_type)

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0
        self.inpaint_model = False
        print("model_type", model_type.name)
        print("adm", self.adm_channels)

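    # Denoising entry point used by the samplers: scales the noisy latent for the UNet,
    # optionally concatenates extra channels (e.g. inpaint mask/latent), casts inputs to
    # the model dtype (falling back to autocast when the device does not support it),
    # runs the UNet, then converts the raw prediction back into a denoised latent.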
    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc] + [c_concat], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()

        if comfy.model_management.supports_dtype(xc.device, dtype):
            precision_scope = lambda a: contextlib.nullcontext(a)
        else:
            precision_scope = torch.autocast
            dtype = torch.float32

        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).float()
        context = context.to(dtype)
        extra_conds = {}
        for o in kwargs:
            extra = kwargs[o]
            if hasattr(extra, "to"):
                extra = extra.to(dtype)
            extra_conds[o] = extra

        with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
            model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()

        return self.model_sampling.calculate_denoised(sigma, model_output, x)

    def get_dtype(self):
        return self.diffusion_model.dtype

    def is_adm(self):
        return self.adm_channels > 0

    def encode_adm(self, **kwargs):
        return None

    def extra_conds(self, **kwargs):
        out = {}
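        # Inpaint models expect extra UNet input channels concatenated onto the noisy
        # latent: a single-channel mask plus the masked latent image. When no mask or
        # latent image is provided, a full mask and a "blank" latent (the latent-space
        # encoding of a neutral image, see blank_inpaint_image_like) are used instead.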
        if self.inpaint_model:
            concat_keys = ("mask", "masked_image")
            cond_concat = []
            denoise_mask = kwargs.get("denoise_mask", None)
            latent_image = kwargs.get("latent_image", None)
            noise = kwargs.get("noise", None)
            device = kwargs["device"]

            def blank_inpaint_image_like(latent_image):
                blank_image = torch.ones_like(latent_image)
                # these are the values for "zero" in pixel space translated to latent space
                blank_image[:,0] *= 0.8223
                blank_image[:,1] *= -0.6876
                blank_image[:,2] *= 0.6364
                blank_image[:,3] *= 0.1380
                return blank_image

            for ck in concat_keys:
                if denoise_mask is not None:
                    if ck == "mask":
                        cond_concat.append(denoise_mask[:,:1].to(device))
                    elif ck == "masked_image":
                        cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
                else:
                    if ck == "mask":
                        cond_concat.append(torch.ones_like(noise)[:,:1])
                    elif ck == "masked_image":
                        cond_concat.append(blank_inpaint_image_like(noise))
            data = torch.cat(cond_concat, dim=1)
            out['c_concat'] = comfy.conds.CONDNoiseShape(data)
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)
        return out

    def load_model_weights(self, sd, unet_prefix=""):
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
            if k.startswith(unet_prefix):
                to_load[k[len(unet_prefix):]] = sd.pop(k)

        to_load = self.model_config.process_unet_state_dict(to_load)
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            print("unet missing:", m)

        if len(u) > 0:
            print("unet unexpected:", u)
        del to_load
        return self

    def process_latent_in(self, latent):
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent):
        return self.latent_format.process_out(latent)

    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
        unet_sd = self.diffusion_model.state_dict()
        unet_state_dict = {}
        for k in unet_sd:
            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)

        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
        if self.get_dtype() == torch.float16:
            clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
            vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)

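        # Empty "v_pred" tensor acts as a marker so the saved checkpoint can later be
        # identified as a v-prediction model.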
        if self.model_type == ModelType.V_PREDICTION:
            unet_state_dict["v_pred"] = torch.tensor([])

        return {**unet_state_dict, **vae_state_dict, **clip_state_dict}

    def set_inpaint(self):
        self.inpaint_model = True

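    # Rough estimate (in bytes) of the memory a forward pass will need for this input
    # shape; flash/xformers attention needs far less than the fallback attention paths,
    # hence the two branches (see the TODO notes).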
    def memory_required(self, input_shape):
        if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
            #TODO: this needs to be tweaked
            area = input_shape[0] * input_shape[2] * input_shape[3]
            return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
        else:
            #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
            area = input_shape[0] * input_shape[2] * input_shape[3]
            return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)


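# Builds the ADM conditioning for unCLIP models from one or more CLIP vision embeddings:
# each embedding is noise-augmented at its requested level, weighted by its strength and
# summed; if several inputs were given, the merged embedding is noise-augmented once more
# at noise_augment_merge.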
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
    adm_inputs = []
    weights = []
    noise_aug = []
    for unclip_cond in unclip_conditioning:
        for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
            weight = unclip_cond["strength"]
            noise_augment = unclip_cond["noise_augmentation"]
            noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
            c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
            adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
            weights.append(weight)
            noise_aug.append(noise_augment)
            adm_inputs.append(adm_out)

    if len(noise_aug) > 1:
        adm_out = torch.stack(adm_inputs).sum(0)
        noise_augment = noise_augment_merge
        noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
        c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
        adm_out = torch.cat((c_adm, noise_level_emb), 1)

    return adm_out

class SD21UNCLIP(BaseModel):
    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

    def encode_adm(self, **kwargs):
        unclip_conditioning = kwargs.get("unclip_conditioning", None)
        device = kwargs["device"]
        if unclip_conditioning is None:
            return torch.zeros((1, self.adm_channels))
        else:
            return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))

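
# SDXL ADM conditioning starts from the 1280-dim pooled CLIP text embedding; when unCLIP
# (image) conditioning is present, it is converted through unclip_adm and truncated to
# the first 1280 channels instead.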
def sdxl_pooled(args, noise_augmentor):
    if "unclip_conditioning" in args:
        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
    else:
        return args["pooled_output"]

class SDXLRefiner(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)

        if kwargs.get("prompt_type", "") == "negative":
            aesthetic_score = kwargs.get("aesthetic_score", 2.5)
        else:
            aesthetic_score = kwargs.get("aesthetic_score", 6)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([aesthetic_score])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

class SDXL(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

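    # The SDXL ADM vector is the pooled CLIP embedding (1280) concatenated with six
    # Timestep(256) embeddings of the size conditioning scalars (height, width, crop_h,
    # crop_w, target_height, target_width): 1280 + 6 * 256 = 2816 channels. The refiner
    # above does the same but replaces the target size with an aesthetic score
    # (1280 + 5 * 256 = 2560 channels).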
    def encode_adm(self, **kwargs):
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)
        target_width = kwargs.get("target_width", width)
        target_height = kwargs.get("target_height", height)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([target_height])))
        out.append(self.embedder(torch.Tensor([target_width])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

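
# Stable Video Diffusion img2vid: the ADM vector embeds fps - 1, motion_bucket_id and the
# augmentation level; extra_conds additionally concatenates the (resized) conditioning
# image latent onto every frame's noise and passes num_video_frames and the
# image_only_indicator through as constants.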
class SVD_img2vid(BaseModel):
    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)

    def encode_adm(self, **kwargs):
        fps_id = kwargs.get("fps", 6) - 1
        motion_bucket_id = kwargs.get("motion_bucket_id", 127)
        augmentation = kwargs.get("augmentation_level", 0)

        out = []
        out.append(self.embedder(torch.Tensor([fps_id])))
        out.append(self.embedder(torch.Tensor([motion_bucket_id])))
        out.append(self.embedder(torch.Tensor([augmentation])))

        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
        return flat

    def extra_conds(self, **kwargs):
        out = {}
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)

        latent_image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        if latent_image is None:
            latent_image = torch.zeros_like(noise)

        if latent_image.shape[1:] != noise.shape[1:]:
            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

        latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])

        out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)

        if "time_conditioning" in kwargs:
            out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])

        out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device))
        out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0])
        return out