import torch
from . import model_base
from . import utils

from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip

from . import supported_models_base
from . import latent_formats

from . import diffusers_convert

class SD15(supported_models_base.BASE):
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format = latent_formats.SD15

    def process_clip_state_dict(self, state_dict):
        k = list(state_dict.keys())
        for x in k:
            if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
                y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
                state_dict[y] = state_dict.pop(x)

        if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in state_dict:
            ids = state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids']
            if ids.dtype == torch.float32:
                state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

        replace_prefix = {}
        replace_prefix["cond_stage_model."] = "clip_l."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
        for p in pop_keys:
            if p in state_dict:
                state_dict.pop(p)

        replace_prefix = {"clip_l.": "cond_stage_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)
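
    # Illustrative sketch (not part of the module logic): the two hooks above are
    # inverse prefix maps around utils.state_dict_prefix_replace. Assuming that helper
    # simply rewrites key prefixes, the round trip for an SD1.5 text encoder key looks
    # roughly like this plain-dict approximation:
    #
    #   sd = {"cond_stage_model.transformer.text_model.final_layer_norm.weight": 0}
    #   loaded = {k.replace("cond_stage_model.", "clip_l.", 1): v for k, v in sd.items()}
    #   saved = {k.replace("clip_l.", "cond_stage_model.", 1): v for k, v in loaded.items()}
    #   assert saved.keys() == sd.keys()
    #
    # The real helper is also used with filter_keys=True above, which restricts the
    # result to matching keys, and the saving hook pops text_projection/logit_scale first.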

    def clip_target(self):
        return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)

class SD20(supported_models_base.BASE):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SD15

    def model_type(self, state_dict, prefix=""):
        if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
            out = state_dict.get(k, None)
            if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                return model_base.ModelType.V_PREDICTION
        return model_base.ModelType.EPS
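
    # Rough sketch of the heuristic above (illustrative only): the check assumes that
    # SD2.x 768-v checkpoints show a larger spread in that particular norm bias than
    # eps-prediction ones, so the standard deviation is used as a cheap detector.
    #
    #   import torch
    #   bias = torch.randn(320) * 0.2   # hypothetical stand-in for the real key
    #   is_v_pred = torch.std(bias, unbiased=False) > 0.09
    #
    # The 0.09 threshold is the empirical guess from the code above, not a guaranteed
    # property of the weights (as the original comment already admits).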

    def process_clip_state_dict(self, state_dict):
        replace_prefix = {}
        replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
        replace_prefix["cond_stage_model.model."] = "clip_h."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        replace_prefix["clip_h"] = "cond_stage_model.model"
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

class SD21UnclipL(SD20):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 1536,
        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 768}


class SD21UnclipH(SD20):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 2048,
        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1024}
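
    # The two unCLIP variants above differ mainly in the image-embedding size fed to the
    # noise augmentor: timestep_dim 768 for SD21UnclipL versus 1024 for SD21UnclipH,
    # which also accounts for the different adm_in_channels (1536 = 2 * 768,
    # 2048 = 2 * 1024). The L/H naming follows the CLIP vision model size the
    # embeddings are expected to come from.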

class SDXLRefiner(supported_models_base.BASE):
    unet_config = {
        "model_channels": 384,
        "use_linear_in_transformer": True,
        "context_dim": 1280,
        "adm_in_channels": 2560,
        "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SDXLRefiner(self, device=device)

    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
        replace_prefix = {}
        replace_prefix["conditioner.embedders.0.model."] = "clip_g."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)

        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
        state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
            state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
        replace_prefix["clip_g"] = "conditioner.embedders.0.model"
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)

class SDXL(supported_models_base.BASE):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

    def model_type(self, state_dict, prefix=""):
        if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
            self.latent_format = latent_formats.SDXL_Playground_2_5()
            self.sampling_settings["sigma_data"] = 0.5
            self.sampling_settings["sigma_max"] = 80.0
            self.sampling_settings["sigma_min"] = 0.002
            return model_base.ModelType.EDM
        elif "edm_vpred.sigma_max" in state_dict:
            self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
            if "edm_vpred.sigma_min" in state_dict:
                self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
            return model_base.ModelType.V_PREDICTION_EDM
        elif "v_pred" in state_dict:
            return model_base.ModelType.V_PREDICTION
        else:
            return model_base.ModelType.EPS
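
    # The branches above infer the prediction type from marker keys stored in the
    # checkpoint itself: "edm_mean"/"edm_std" identify Playground v2.5 style EDM models
    # (which also swap in their own latent format and sigma range), "edm_vpred.sigma_max"
    # marks v-prediction EDM models with explicit sigma bounds, a bare "v_pred" key marks
    # plain v-prediction, and everything else falls back to epsilon prediction.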

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
        replace_prefix = {}

        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
        replace_prefix["conditioner.embedders.1.model."] = "clip_g."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)

        state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        keys_to_replace = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        for k in state_dict:
            if k.startswith("clip_l"):
                state_dict_g[k] = state_dict[k]

        state_dict_g["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1))
        pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
        for p in pop_keys:
            if p in state_dict_g:
                state_dict_g.pop(p)

        replace_prefix["clip_g"] = "conditioner.embedders.1.model"
        replace_prefix["clip_l"] = "conditioner.embedders.0"
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
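
# The SDXL subclasses below (SSD1B, Segmind_Vega, KOALA_700M, KOALA_1B) inherit all of
# the CLIP and state-dict handling above and differ only in their unet_config, mainly
# the transformer_depth pattern used to recognize each distilled UNet layout.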

class SSD1B(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 4, 4],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

class Segmind_Vega(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 1, 1, 2, 2],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

class KOALA_700M(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 2, 5],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

class KOALA_1B(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 2, 6],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

class SVD_img2vid(supported_models_base.BASE):
    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 768,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."

    latent_format = latent_formats.SD15

    sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SVD_img2vid(self, device=device)
        return out

    def clip_target(self):
        return None
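
    # SVD is an image-to-video model: clip_target() returns None because there is no
    # text encoder to load, conditioning comes from the CLIP vision tower addressed by
    # clip_vision_prefix, and sampling uses a wide continuous sigma range
    # (0.002 to 700.0) instead of the usual discrete SD schedule.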

class SV3D_u(SVD_img2vid):
    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 256,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    vae_key_prefix = ["conditioner.embedders.1.encoder."]

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SV3D_u(self, device=device)
        return out

class SV3D_p(SV3D_u):
    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 1280,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }


    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SV3D_p(self, device=device)
        return out

class Stable_Zero123(supported_models_base.BASE):
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    required_keys = {
        "cc_projection.weight": None,
        "cc_projection.bias": None,
    }
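
    # The unet_config above is identical to other 8-channel SD1.x variants (compare the
    # SD15_instructpix2pix config later in this file), so the cc_projection weights
    # listed in required_keys are what allow a Stable Zero123 checkpoint to be told
    # apart during detection.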

    clip_vision_prefix = "cond_stage_model.model.visual."

    latent_format = latent_formats.SD15

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
        return out

    def clip_target(self):
        return None

class SD_X4Upscaler(SD20):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 256,
        'in_channels': 7,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config = {
        "disable_self_attentions": [True, True, True, False],
        "num_classes": 1000,
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format = latent_formats.SD_X4

    sampling_settings = {
        "linear_start": 0.0001,
        "linear_end": 0.02,
    }

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SD_X4Upscaler(self, device=device)
        return out

class Stable_Cascade_C(supported_models_base.BASE):
    unet_config = {
        "stable_cascade_stage": 'c',
    }

    unet_extra_config = {}

    latent_format = latent_formats.SC_Prior
    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    sampling_settings = {
        "shift": 2.0,
    }

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoder."]
    clip_vision_prefix = "clip_l_vision."

    def process_unet_state_dict(self, state_dict):
        key_list = list(state_dict.keys())
        for y in ["weight", "bias"]:
            suffix = "in_proj_{}".format(y)
            keys = filter(lambda a: a.endswith(suffix), key_list)
            for k_from in keys:
                weights = state_dict.pop(k_from)
                prefix = k_from[:-(len(suffix) + 1)]
                shape_from = weights.shape[0] // 3
                for x in range(3):
                    p = ["to_q", "to_k", "to_v"]
                    k_to = "{}.{}.{}".format(prefix, p[x], y)
                    state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
        return state_dict
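
    # Rough sketch of the conversion above (illustrative, standalone): a fused attention
    # projection of shape (3 * dim, ...) is split into three equal to_q / to_k / to_v
    # chunks along dim 0, keeping whatever prefix preceded "in_proj_weight"/"in_proj_bias".
    #
    #   import torch
    #   dim = 64
    #   sd = {"blocks.0.attn.in_proj_weight": torch.randn(3 * dim, dim)}  # made-up key
    #   w = sd.pop("blocks.0.attn.in_proj_weight")
    #   for i, name in enumerate(["to_q", "to_k", "to_v"]):
    #       sd["blocks.0.attn.{}.weight".format(name)] = w[i * dim:(i + 1) * dim]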

    def process_clip_state_dict(self, state_dict):
        state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
        if "clip_g.text_projection" in state_dict:
            state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1)
        return state_dict
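
    # The transpose above accounts for the usual storage difference between an
    # OpenCLIP-style text_projection parameter (applied as x @ proj, shape (width, out))
    # and an nn.Linear weight (stored as (out_features, in_features)), so the stored
    # projection can be loaded as a regular Linear weight.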

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.StableCascade_C(self, device=device)
        return out

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)

class Stable_Cascade_B(Stable_Cascade_C):
    unet_config = {
        "stable_cascade_stage": 'b',
    }

    unet_extra_config = {}

    latent_format = latent_formats.SC_B
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    sampling_settings = {
        "shift": 1.0,
    }

    clip_vision_prefix = None

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.StableCascade_B(self, device=device)
        return out

class SD15_instructpix2pix(SD15):
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SD15_instructpix2pix(self, device=device)

class SDXL_instructpix2pix(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)

models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]

models += [SVD_img2vid]
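

# Illustrative sketch (not the actual detection code, which lives outside this file):
# the classes in `models` act as candidate configurations, and checkpoint loading scans
# them in order until one whose unet_config entries all match the settings detected from
# the UNet state dict. Roughly:
#
#   def pick_model_config(detected_unet_config):   # hypothetical helper name
#       for candidate in models:
#           if all(detected_unet_config.get(k) == v for k, v in candidate.unet_config.items()):
#               return candidate
#       return None
#
# The real matching is more involved (for example it also checks required_keys against
# the checkpoint), but the list order matters either way: more specific configs such as
# Stable_Zero123 and the instructpix2pix variants are listed before the generic SD15/SDXL
# entries so they get matched first.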