import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import comfy.model_management
import comfy.conds
import numpy as np
from enum import Enum
from . import utils

class ModelType(Enum):
    EPS = 1
    V_PREDICTION = 2


#NOTE: all this sampling stuff will be moved
class EPS:
    def calculate_input(self, sigma, noise):
        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
        return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5

    def calculate_denoised(self, sigma, model_output, model_input):
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input - model_output * sigma


class V_PREDICTION(EPS):
    def calculate_denoised(self, sigma, model_output, model_input):
        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
        return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
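
# For reference, with sigma_data = 1.0 (as set by ModelSamplingDiscrete below)
# the two parameterizations above reduce to:
#
#   EPS:          denoised = x - sigma * model_output
#   V_PREDICTION: denoised = x / (sigma**2 + 1) - model_output * sigma / (sigma**2 + 1)**0.5
#
# i.e. an eps model predicts the added noise directly, while a v model
# predicts the "velocity" mix of signal and noise (Salimans & Ho,
# "Progressive Distillation for Fast Sampling of Diffusion Models").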


class ModelSamplingDiscrete(torch.nn.Module):
    def __init__(self, model_config):
        super().__init__()
        self._register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
        self.sigma_data = 1.0

    def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        if given_betas is not None:
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end

        # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
        # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
        # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

        sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)

        self.register_buffer('sigmas', sigmas)
        self.register_buffer('log_sigmas', sigmas.log())

    @property
    def sigma_min(self):
        return self.sigmas[0]

    @property
    def sigma_max(self):
        return self.sigmas[-1]

    def timestep(self, sigma):
        log_sigma = sigma.log()
        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def sigma(self, timestep):
        t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
        low_idx = t.floor().long()
        high_idx = t.ceil().long()
        w = t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp()

    def percent_to_sigma(self, percent):
        return self.sigma(torch.tensor(percent * 999.0))
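
# Usage sketch (model_config is assumed to expose a beta_schedule attribute,
# as __init__ above requires):
#
#   ms = ModelSamplingDiscrete(model_config)
#   ms.sigma(torch.tensor(999.0))   # largest sigma, equal to ms.sigma_max
#   ms.timestep(ms.sigma_max)       # maps back to timestep 999
#   ms.percent_to_sigma(1.0)        # 0..1 denoise fraction -> sigma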

def model_sampling(model_config, model_type):
    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION

    s = ModelSamplingDiscrete

    class ModelSampling(s, c):
        pass

    return ModelSampling(model_config)
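
# The class assembled here is a plain mixin: ModelSamplingDiscrete supplies the
# discrete sigma schedule, while EPS / V_PREDICTION supply calculate_input and
# calculate_denoised. For example:
#
#   ms = model_sampling(model_config, ModelType.V_PREDICTION)
#   isinstance(ms, ModelSamplingDiscrete)   # True
#   isinstance(ms, V_PREDICTION)            # True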



class BaseModel(torch.nn.Module):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__()

        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config

        if not unet_config.get("disable_unet_model_creation", False):
            self.diffusion_model = UNetModel(**unet_config, device=device)
        self.model_type = model_type
        self.model_sampling = model_sampling(model_config, model_type)

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0
        self.inpaint_model = False
        print("model_type", model_type.name)
        print("adm", self.adm_channels)

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc] + [c_concat], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()
        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).float()
        context = context.to(dtype)
        extra_conds = {}
        for o in kwargs:
            extra_conds[o] = kwargs[o].to(dtype)
        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)

    def get_dtype(self):
        return self.diffusion_model.dtype

    def is_adm(self):
        return self.adm_channels > 0

    def encode_adm(self, **kwargs):
        return None

    def extra_conds(self, **kwargs):
        out = {}
        if self.inpaint_model:
            concat_keys = ("mask", "masked_image")
            cond_concat = []
            denoise_mask = kwargs.get("denoise_mask", None)
            latent_image = kwargs.get("latent_image", None)
            noise = kwargs.get("noise", None)
            device = kwargs["device"]

            def blank_inpaint_image_like(latent_image):
                blank_image = torch.ones_like(latent_image)
                # these are the values for "zero" in pixel space translated to latent space
                blank_image[:,0] *= 0.8223
                blank_image[:,1] *= -0.6876
                blank_image[:,2] *= 0.6364
                blank_image[:,3] *= 0.1380
                return blank_image

            for ck in concat_keys:
                if denoise_mask is not None:
                    if ck == "mask":
                        cond_concat.append(denoise_mask[:,:1].to(device))
                    elif ck == "masked_image":
                        cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
                else:
                    if ck == "mask":
                        cond_concat.append(torch.ones_like(noise)[:,:1])
                    elif ck == "masked_image":
                        cond_concat.append(blank_inpaint_image_like(noise))
            data = torch.cat(cond_concat, dim=1)
            out['c_concat'] = comfy.conds.CONDNoiseShape(data)
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)
        return out
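
    # Shape note for the inpaint path above: c_concat stacks a 1-channel mask
    # with the 4-channel masked latent image, so an SD1.x/SD2.x inpaint UNet
    # receives 4 + 5 = 9 input channels.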

    def load_model_weights(self, sd, unet_prefix=""):
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
            if k.startswith(unet_prefix):
                to_load[k[len(unet_prefix):]] = sd.pop(k)

        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            print("unet missing:", m)

        if len(u) > 0:
            print("unet unexpected:", u)
        del to_load
        return self
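
    # Usage sketch: in a standard Stable Diffusion checkpoint the UNet weights
    # live under the "model.diffusion_model." prefix, so a typical call is:
    #
    #   model.load_model_weights(sd, unet_prefix="model.diffusion_model.")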

    def process_latent_in(self, latent):
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent):
        return self.latent_format.process_out(latent)

    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
        unet_sd = self.diffusion_model.state_dict()
        unet_state_dict = {}
        for k in unet_sd:
            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)

        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
        if self.get_dtype() == torch.float16:
            clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
            vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)

        if self.model_type == ModelType.V_PREDICTION:
            unet_state_dict["v_pred"] = torch.tensor([])

        return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
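
    # The empty "v_pred" tensor saved above acts as a marker so the checkpoint
    # can be recognized as a v-prediction model when it is loaded back.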

    def set_inpaint(self):
        self.inpaint_model = True

def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
    adm_inputs = []
    weights = []
    noise_aug = []
    for unclip_cond in unclip_conditioning:
        for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
            weight = unclip_cond["strength"]
            noise_augment = unclip_cond["noise_augmentation"]
            noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
            c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
            adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
            weights.append(weight)
            noise_aug.append(noise_augment)
            adm_inputs.append(adm_out)

    if len(noise_aug) > 1:
        adm_out = torch.stack(adm_inputs).sum(0)
        noise_augment = noise_augment_merge
        noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
        c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
        adm_out = torch.cat((c_adm, noise_level_emb), 1)

    return adm_out
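
# Each unclip_conditioning entry is expected to look roughly like the following
# (the strength / noise_augmentation values here are illustrative):
#
#   {"clip_vision_output": <object with an .image_embeds tensor>,
#    "strength": 1.0,             # weight applied to the augmented embedding
#    "noise_augmentation": 0.0}   # 0..1 fraction of the augmentor's max noise level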

class SD21UNCLIP(BaseModel):
    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

    def encode_adm(self, **kwargs):
        unclip_conditioning = kwargs.get("unclip_conditioning", None)
        device = kwargs["device"]
        if unclip_conditioning is None:
            return torch.zeros((1, self.adm_channels))
        else:
            return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))

def sdxl_pooled(args, noise_augmentor):
    if "unclip_conditioning" in args:
        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
    else:
        return args["pooled_output"]

class SDXLRefiner(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)

        if kwargs.get("prompt_type", "") == "negative":
            aesthetic_score = kwargs.get("aesthetic_score", 2.5)
        else:
            aesthetic_score = kwargs.get("aesthetic_score", 6)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([aesthetic_score])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
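
    # The ADM vector built here is the 1280-dim pooled embedding plus five
    # 256-dim Timestep embeddings (height, width, crop_h, crop_w,
    # aesthetic_score), i.e. 1280 + 5 * 256 = 2560 channels, which should
    # match the refiner UNet's adm_in_channels.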

class SDXL(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)
        target_width = kwargs.get("target_width", width)
        target_height = kwargs.get("target_height", height)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([target_height])))
        out.append(self.embedder(torch.Tensor([target_width])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
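
    # The ADM vector here is the 1280-dim pooled embedding plus six 256-dim
    # Timestep embeddings (height, width, crop_h, crop_w, target_height,
    # target_width), i.e. 1280 + 6 * 256 = 2816 channels, which should match
    # the SDXL UNet's adm_in_channels.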