model_base.py
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
import comfy.conds
from enum import Enum
from . import utils

class ModelType(Enum):
    EPS = 1           # network predicts the added noise (epsilon)
    V_PREDICTION = 2  # network predicts v (velocity) instead of epsilon


from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete

def model_sampling(model_config, model_type):
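    """Compose a sampling object by mixing a prediction-type class (EPS or
    V_PREDICTION) into a discrete noise schedule (ModelSamplingDiscrete)."""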
    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION

    s = ModelSamplingDiscrete

    class ModelSampling(s, c):
        pass

    return ModelSampling(model_config)
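# Minimal usage sketch (assumes `model_config` carries valid sampling settings):
#   ms = model_sampling(model_config, ModelType.V_PREDICTION)
#   isinstance(ms, V_PREDICTION) and isinstance(ms, ModelSamplingDiscrete)  # True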


class BaseModel(torch.nn.Module):
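    """Wraps a diffusion UNet together with its sampling math, latent format
    and conditioning logic."""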
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__()

        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config

        if not unet_config.get("disable_unet_model_creation", False):
            self.diffusion_model = UNetModel(**unet_config, device=device)
        self.model_type = model_type
        self.model_sampling = model_sampling(model_config, model_type)

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0
        self.inpaint_model = False
        print("model_type", model_type.name)
        print("adm", self.adm_channels)

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
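        # `t` arrives as a sigma value: scale the input accordingly, convert
        # sigma to the discrete timestep the UNet expects, and map the raw
        # network output back to a denoised latent at the end.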
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc, c_concat], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()
        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).float()
        context = context.to(dtype)
        extra_conds = {}
        for o in kwargs:
            extra = kwargs[o]
            if hasattr(extra, "to"):
                extra = extra.to(dtype)
            extra_conds[o] = extra
        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)

    def get_dtype(self):
        return self.diffusion_model.dtype

    def is_adm(self):
        return self.adm_channels > 0

    def encode_adm(self, **kwargs):
        return None

    def extra_conds(self, **kwargs):
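        """Collect conditioning beyond cross-attention: inpainting concat
        channels (mask + masked latent) and the ADM vector (`y`)."""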
        out = {}
        if self.inpaint_model:
            concat_keys = ("mask", "masked_image")
            cond_concat = []
            denoise_mask = kwargs.get("denoise_mask", None)
            latent_image = kwargs.get("latent_image", None)
            noise = kwargs.get("noise", None)
            device = kwargs["device"]

            def blank_inpaint_image_like(latent_image):
                blank_image = torch.ones_like(latent_image)
                # these are the values for "zero" in pixel space translated to latent space
                blank_image[:,0] *= 0.8223
                blank_image[:,1] *= -0.6876
                blank_image[:,2] *= 0.6364
                blank_image[:,3] *= 0.1380
                return blank_image

            for ck in concat_keys:
                if denoise_mask is not None:
                    if ck == "mask":
                        cond_concat.append(denoise_mask[:,:1].to(device))
                    elif ck == "masked_image":
                        cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
                else:
                    if ck == "mask":
                        cond_concat.append(torch.ones_like(noise)[:,:1])
                    elif ck == "masked_image":
                        cond_concat.append(blank_inpaint_image_like(noise))
            data = torch.cat(cond_concat, dim=1)
            out['c_concat'] = comfy.conds.CONDNoiseShape(data)
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)
        return out

    def load_model_weights(self, sd, unet_prefix=""):
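        """Load UNet weights from `sd`, stripping `unet_prefix` from the key
        names. Missing or unexpected keys are printed rather than raised."""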
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
            if k.startswith(unet_prefix):
                to_load[k[len(unet_prefix):]] = sd.pop(k)

        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            print("unet missing:", m)

        if len(u) > 0:
            print("unet unexpected:", u)
        del to_load
        return self

    def process_latent_in(self, latent):
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent):
        return self.latent_format.process_out(latent)

    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
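        """Merge UNet, VAE and CLIP weights into a single checkpoint-style
        state dict, applying each component's saving transforms."""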
        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
        unet_sd = self.diffusion_model.state_dict()
        unet_state_dict = {}
        for k in unet_sd:
            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)

        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
        if self.get_dtype() == torch.float16:
            clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
            vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)

        if self.model_type == ModelType.V_PREDICTION:
            unet_state_dict["v_pred"] = torch.tensor([])

        return {**unet_state_dict, **vae_state_dict, **clip_state_dict}

    def set_inpaint(self):
        self.inpaint_model = True

def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
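    """Build the ADM conditioning vector for unCLIP models: each CLIP image
    embedding is noise-augmented to its requested level, concatenated with
    the noise-level embedding and scaled by its strength. Multiple
    conditionings are summed and noise-augmented once more at
    `noise_augment_merge`."""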
    adm_inputs = []
    weights = []
    noise_aug = []
    for unclip_cond in unclip_conditioning:
        for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
            weight = unclip_cond["strength"]
            noise_augment = unclip_cond["noise_augmentation"]
            noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
            c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
            adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
            weights.append(weight)
            noise_aug.append(noise_augment)
            adm_inputs.append(adm_out)

    if len(noise_aug) > 1:
        adm_out = torch.stack(adm_inputs).sum(0)
        noise_augment = noise_augment_merge
        noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
        c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
        adm_out = torch.cat((c_adm, noise_level_emb), 1)

    return adm_out

class SD21UNCLIP(BaseModel):
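    """SD 2.1 unCLIP: feeds noise-augmented CLIP image embeddings to the
    UNet through the ADM (`y`) input."""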
    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

    def encode_adm(self, **kwargs):
        unclip_conditioning = kwargs.get("unclip_conditioning", None)
        device = kwargs["device"]
        if unclip_conditioning is None:
            return torch.zeros((1, self.adm_channels))
        else:
            return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))

def sdxl_pooled(args, noise_augmentor):
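    """Return the pooled 1280-dim conditioning: the text encoder's pooled
    output, or the first 1280 channels of the unCLIP ADM vector when
    unclip_conditioning is present."""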
    if "unclip_conditioning" in args:
        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
    else:
        return args["pooled_output"]

class SDXLRefiner(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
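        # ADM layout (refiner): pooled CLIP (1280) followed by 256-dim Fourier
        # embeddings of height, width, crop_h, crop_w and aesthetic_score,
        # i.e. 1280 + 5 * 256 = 2560 channels.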
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)

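        # Refiner convention: the negative prompt defaults to a low aesthetic
        # score, the positive prompt to a high one.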
        if kwargs.get("prompt_type", "") == "negative":
            aesthetic_score = kwargs.get("aesthetic_score", 2.5)
        else:
            aesthetic_score = kwargs.get("aesthetic_score", 6)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([aesthetic_score])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

class SDXL(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
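        # ADM layout (base): pooled CLIP (1280) followed by 256-dim Fourier
        # embeddings of height, width, crop_h, crop_w, target_height and
        # target_width, i.e. 1280 + 6 * 256 = 2816 channels.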
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)
        target_width = kwargs.get("target_width", width)
        target_height = kwargs.get("target_height", height)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([target_height])))
        out.append(self.embedder(torch.Tensor([target_width])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
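
# Usage sketch (hypothetical values; encode_adm is normally invoked through
# extra_conds during sampling):
#   adm = sdxl.encode_adm(pooled_output=pooled, width=1024, height=1024,
#                         crop_w=0, crop_h=0, target_width=1024, target_height=1024)
#   adm.shape  # -> (batch, 2816)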