import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import comfy.model_management
import comfy.conds
import numpy as np
from enum import Enum
from . import utils

class ModelType(Enum):
    EPS = 1
    V_PREDICTION = 2


#NOTE: all this sampling stuff will be moved
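# EPS: the network predicts the noise (epsilon) that was added to the sample.
# calculate_input applies the input preconditioning c_in = 1 / sqrt(sigma^2 + sigma_data^2);
# calculate_denoised recovers the denoised estimate x0 = x - sigma * eps.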
class EPS:
    def calculate_input(self, sigma, noise):
        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
        return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5

    def calculate_denoised(self, sigma, model_output, model_input):
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma


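# V_PREDICTION: the network predicts v instead of the noise; calculate_denoised
# applies the standard v-prediction-to-x0 conversion, using the same
# sigma/sigma_data preconditioning terms as EPS.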
class V_PREDICTION(EPS):
    def calculate_denoised(self, sigma, model_output, model_input):
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return (model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
                - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5)


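# Discrete (DDPM-style) schedule: precomputes sigma_t = sqrt((1 - alphas_cumprod_t) / alphas_cumprod_t)
# for each of the 1000 training timesteps, so continuous sigma values can be
# mapped to and from discrete timesteps.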
class ModelSamplingDiscrete(torch.nn.Module):
    def __init__(self, model_config):
        super().__init__()
        self._register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
        self.sigma_data = 1.0

    def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        if given_betas is not None:
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end

        # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
        # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
        # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

        sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)

        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def timestep(self, sigma):
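        # Nearest discrete timestep for a given sigma, matched in log-sigma space.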
        log_sigma = sigma.log()
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def sigma(self, timestep):
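        # Sigma for a (possibly fractional) timestep, linearly interpolated in log-sigma space.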
        t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
        low_idx = t.floor().long()
        high_idx = t.ceil().long()
        w = t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp()

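
# Compose the sampling helper at runtime: the schedule (ModelSamplingDiscrete) is
# mixed with the prediction type (EPS or V_PREDICTION) via multiple inheritance.
# The result exposes both the schedule (sigma, timestep) and the prediction
# conversions (calculate_input, calculate_denoised).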
def model_sampling(model_config, model_type):
    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION

    s = ModelSamplingDiscrete

    class ModelSampling(s, c):
        pass

    return ModelSampling(model_config)


class BaseModel(torch.nn.Module):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__()

        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config

        if not unet_config.get("disable_unet_model_creation", False):
            self.diffusion_model = UNetModel(**unet_config, device=device)
        self.model_type = model_type
        self.model_sampling = model_sampling(model_config, model_type)

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0
        self.inpaint_model = False
        print("model_type", model_type.name)
        print("adm", self.adm_channels)

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
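        # Precondition the input for this sigma, run the UNet on the matching
        # discrete timestep, then convert the prediction to a denoised estimate.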
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc, c_concat], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()
        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).to(dtype)
        context = context.to(dtype)
        extra_conds = {}
        for o in kwargs:
            extra_conds[o] = kwargs[o].to(dtype)
        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)

    def get_dtype(self):
        return self.diffusion_model.dtype

    def is_adm(self):
        return self.adm_channels > 0

    def encode_adm(self, **kwargs):
        return None

    def extra_conds(self, **kwargs):
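        # Gather model-specific conditioning. Inpaint models receive an extra
        # channel-concat conditioning: the one-channel denoise mask plus the
        # masked latent image; ADM conditioning (if any) goes in 'y'.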
        out = {}
        if self.inpaint_model:
            concat_keys = ("mask", "masked_image")
            cond_concat = []
            denoise_mask = kwargs.get("denoise_mask", None)
            latent_image = kwargs.get("latent_image", None)
            noise = kwargs.get("noise", None)
            device = kwargs["device"]

            def blank_inpaint_image_like(latent_image):
                blank_image = torch.ones_like(latent_image)
                # these are the values for "zero" in pixel space translated to latent space
                blank_image[:,0] *= 0.8223
                blank_image[:,1] *= -0.6876
                blank_image[:,2] *= 0.6364
                blank_image[:,3] *= 0.1380
                return blank_image

            for ck in concat_keys:
                if denoise_mask is not None:
                    if ck == "mask":
                        cond_concat.append(denoise_mask[:,:1].to(device))
                    elif ck == "masked_image":
                        cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
                else:
                    if ck == "mask":
                        cond_concat.append(torch.ones_like(noise)[:,:1])
                    elif ck == "masked_image":
                        cond_concat.append(blank_inpaint_image_like(noise))
            data = torch.cat(cond_concat, dim=1)
            out['c_concat'] = comfy.conds.CONDNoiseShape(data)
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)
        return out

    def load_model_weights(self, sd, unet_prefix=""):
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
            if k.startswith(unet_prefix):
                to_load[k[len(unet_prefix):]] = sd.pop(k)

        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            print("unet missing:", m)

        if len(u) > 0:
            print("unet unexpected:", u)
        del to_load
        return self

    def process_latent_in(self, latent):
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent):
        return self.latent_format.process_out(latent)

    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
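        # Re-key the CLIP, UNet and VAE state dicts into checkpoint format,
        # resolve any lowvram-offloaded UNet weights, and cast the CLIP and VAE
        # weights to fp16 when the UNet itself is fp16.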
        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
        unet_sd = self.diffusion_model.state_dict()
        unet_state_dict = {}
        for k in unet_sd:
            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)

        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
        if self.get_dtype() == torch.float16:
            clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
            vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)

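        # An empty "v_pred" tensor is saved as a marker so the prediction type
        # can be detected when the checkpoint is loaded again.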
        if self.model_type == ModelType.V_PREDICTION:
            unet_state_dict["v_pred"] = torch.tensor([])

        return {**unet_state_dict, **vae_state_dict, **clip_state_dict}

    def set_inpaint(self):
        self.inpaint_model = True

def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
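    # Each CLIP vision embed is noise-augmented at its requested level and weighted;
    # if several embeds are given, their sum is noise-augmented once more at the
    # merge level before being returned as the ADM conditioning.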
    adm_inputs = []
    weights = []
    noise_aug = []
    for unclip_cond in unclip_conditioning:
        for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
            weight = unclip_cond["strength"]
            noise_augment = unclip_cond["noise_augmentation"]
            noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
            c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
            adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
            weights.append(weight)
            noise_aug.append(noise_augment)
            adm_inputs.append(adm_out)

    if len(noise_aug) > 1:
        adm_out = torch.stack(adm_inputs).sum(0)
        noise_augment = noise_augment_merge
        noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
        c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
        adm_out = torch.cat((c_adm, noise_level_emb), 1)

    return adm_out

class SD21UNCLIP(BaseModel):
    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

    def encode_adm(self, **kwargs):
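        # ADM conditioning comes from the noise-augmented CLIP vision embeds;
        # a zero vector is used when no unclip conditioning is supplied.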
        unclip_conditioning = kwargs.get("unclip_conditioning", None)
        device = kwargs["device"]
        if unclip_conditioning is None:
            return torch.zeros((1, self.adm_channels))
        else:
            return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))

def sdxl_pooled(args, noise_augmentor):
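    # Pooled conditioning for SDXL: reuse the unCLIP ADM path (first 1280 dims)
    # when image conditioning is present, otherwise take the pooled text output.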
    if "unclip_conditioning" in args:
        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
    else:
        return args["pooled_output"]

class SDXLRefiner(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, timestep_dim=1280)

    def encode_adm(self, **kwargs):
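        # Refiner ADM: pooled CLIP embedding plus Timestep(256) embeddings of the
        # image size, crop coordinates and aesthetic score (the refiner conditions
        # on the score instead of a target size).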
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)

        if kwargs.get("prompt_type", "") == "negative":
            aesthetic_score = kwargs.get("aesthetic_score", 2.5)
        else:
            aesthetic_score = kwargs.get("aesthetic_score", 6)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([aesthetic_score])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

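
# SDXL base model: conditions on original size, crop coordinates and target size
# (the size/crop micro-conditioning described in the SDXL paper).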
class SDXL(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, timestep_dim=1280)

    def encode_adm(self, **kwargs):
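        # SDXL ADM: pooled CLIP embedding plus Timestep(256) embeddings of the
        # original size, crop coordinates and target size.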
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)
        target_width = kwargs.get("target_width", width)
        target_height = kwargs.get("target_height", height)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([target_height])))
        out.append(self.embedder(torch.Tensor([target_width])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)