import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
import comfy.conds
from enum import Enum
from . import utils

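# How the diffusion model parameterizes its output (noise prediction vs. v-prediction).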
class ModelType(Enum):
    EPS = 1
    V_PREDICTION = 2


from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete

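# Build the sampling helper for a model: combine the prediction type (EPS or
# V_PREDICTION) with the discrete noise schedule and return an instance of it.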
def model_sampling(model_config, model_type):
    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION

    s = ModelSamplingDiscrete

    class ModelSampling(s, c):
        pass

    return ModelSampling(model_config)


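# Wraps the UNet diffusion model together with its latent format, sampling
# helper and conditioning logic for a given model configuration.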
class BaseModel(torch.nn.Module):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__()

        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config

        if not unet_config.get("disable_unet_model_creation", False):
            self.diffusion_model = UNetModel(**unet_config, device=device)
        self.model_type = model_type
        self.model_sampling = model_sampling(model_config, model_type)

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0
        self.inpaint_model = False
        print("model_type", model_type.name)
        print("adm", self.adm_channels)

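    # One denoising step: scale the noisy latent, concatenate any channel-wise
    # conditioning, run the UNet and turn its raw output into a denoised latent.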
    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc] + [c_concat], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()
        xc = xc.to(dtype)
        t = self.model_sampling.timestep(t).float()
        context = context.to(dtype)
        extra_conds = {}
        for o in kwargs:
            extra_conds[o] = kwargs[o].to(dtype)
        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
        return self.model_sampling.calculate_denoised(sigma, model_output, x)

    def get_dtype(self):
        return self.diffusion_model.dtype

    def is_adm(self):
        return self.adm_channels > 0

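    # Subclasses return the ADM conditioning vector (the UNet "y" input); the base model has none.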
    def encode_adm(self, **kwargs):
        return None

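    # Gather the extra conditioning passed to the UNet: inpainting concat channels
    # (mask + masked latent) when enabled, plus the ADM vector as "y".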
    def extra_conds(self, **kwargs):
        out = {}
        if self.inpaint_model:
            concat_keys = ("mask", "masked_image")
            cond_concat = []
            denoise_mask = kwargs.get("denoise_mask", None)
            latent_image = kwargs.get("latent_image", None)
            noise = kwargs.get("noise", None)
            device = kwargs["device"]

            def blank_inpaint_image_like(latent_image):
                blank_image = torch.ones_like(latent_image)
                # these are the values for "zero" in pixel space translated to latent space
                blank_image[:,0] *= 0.8223
                blank_image[:,1] *= -0.6876
                blank_image[:,2] *= 0.6364
                blank_image[:,3] *= 0.1380
                return blank_image

            for ck in concat_keys:
                if denoise_mask is not None:
                    if ck == "mask":
                        cond_concat.append(denoise_mask[:,:1].to(device))
                    elif ck == "masked_image":
                        cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
                else:
                    if ck == "mask":
                        cond_concat.append(torch.ones_like(noise)[:,:1])
                    elif ck == "masked_image":
                        cond_concat.append(blank_inpaint_image_like(noise))
            data = torch.cat(cond_concat, dim=1)
            out['c_concat'] = comfy.conds.CONDNoiseShape(data)
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = comfy.conds.CONDRegular(adm)
        return out

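    # Load the UNet weights whose keys start with unet_prefix from sd,
    # reporting any missing or unexpected keys.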
    def load_model_weights(self, sd, unet_prefix=""):
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
            if k.startswith(unet_prefix):
                to_load[k[len(unet_prefix):]] = sd.pop(k)

        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            print("unet missing:", m)

        if len(u) > 0:
            print("unet unexpected:", u)
        del to_load
        return self

    def process_latent_in(self, latent):
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent):
        return self.latent_format.process_out(latent)

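    # Merge the UNet, VAE and CLIP weights into a single checkpoint-style state
    # dict, matching the CLIP/VAE dtype to the UNet and tagging v-prediction models.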
    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
        unet_sd = self.diffusion_model.state_dict()
        unet_state_dict = {}
        for k in unet_sd:
            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)

        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
        if self.get_dtype() == torch.float16:
            clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
            vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)

        if self.model_type == ModelType.V_PREDICTION:
            unet_state_dict["v_pred"] = torch.tensor([])

        return {**unet_state_dict, **vae_state_dict, **clip_state_dict}

    def set_inpaint(self):
        self.inpaint_model = True

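# Build the unCLIP ADM vector: noise-augment each CLIP vision embedding, weight it,
# and, if there is more than one, merge them with a second augmentation pass.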
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
    adm_inputs = []
    weights = []
    noise_aug = []
    for unclip_cond in unclip_conditioning:
        for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
            weight = unclip_cond["strength"]
            noise_augment = unclip_cond["noise_augmentation"]
            noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
            c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
            adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
            weights.append(weight)
            noise_aug.append(noise_augment)
            adm_inputs.append(adm_out)

    if len(noise_aug) > 1:
        adm_out = torch.stack(adm_inputs).sum(0)
        noise_augment = noise_augment_merge
        noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
        c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
        adm_out = torch.cat((c_adm, noise_level_emb), 1)

    return adm_out

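# SD 2.1 unCLIP model: conditions the UNet on noise-augmented CLIP image embeddings.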
class SD21UNCLIP(BaseModel):
    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

    def encode_adm(self, **kwargs):
        unclip_conditioning = kwargs.get("unclip_conditioning", None)
        device = kwargs["device"]
        if unclip_conditioning is None:
            return torch.zeros((1, self.adm_channels))
        else:
            return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))

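# Pooled CLIP embedding for SDXL models: taken from the unCLIP conditioning when
# present (first 1280 channels of the unCLIP ADM), otherwise the pooled text output.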
def sdxl_pooled(args, noise_augmentor):
    if "unclip_conditioning" in args:
        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280]
    else:
        return args["pooled_output"]

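# SDXL refiner model: ADM conditioning includes the aesthetic score.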
class SDXLRefiner(BaseModel):
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, timestep_dim=1280)

    def encode_adm(self, **kwargs):
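        # Refiner ADM: pooled CLIP embedding followed by timestep embeddings of
        # (height, width, crop_h, crop_w, aesthetic_score).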
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)

        if kwargs.get("prompt_type", "") == "negative":
            aesthetic_score = kwargs.get("aesthetic_score", 2.5)
        else:
            aesthetic_score = kwargs.get("aesthetic_score", 6)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([aesthetic_score])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

class SDXL(BaseModel):
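    # SDXL base model: size/crop-conditioned ADM built from the pooled CLIP embedding.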
    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(noise_schedule_config={"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, timestep_dim=1280)

    def encode_adm(self, **kwargs):
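        # SDXL ADM: pooled CLIP embedding followed by timestep embeddings of
        # (height, width, crop_h, crop_w, target_height, target_width).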
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)
        target_width = kwargs.get("target_width", width)
        target_height = kwargs.get("target_height", height)

        out = []
        out.append(self.embedder(torch.Tensor([height])))
        out.append(self.embedder(torch.Tensor([width])))
        out.append(self.embedder(torch.Tensor([crop_h])))
        out.append(self.embedder(torch.Tensor([crop_w])))
        out.append(self.embedder(torch.Tensor([target_height])))
        out.append(self.embedder(torch.Tensor([target_width])))
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)