import torch
from . import model_base
from . import utils

from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip

from . import supported_models_base
from . import latent_formats

from . import diffusers_convert

class SD15(supported_models_base.BASE):
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format = latent_formats.SD15

    def process_clip_state_dict(self, state_dict):
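        # insert the missing "text_model." level into CLIP keys from older checkpoints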
        k = list(state_dict.keys())
        for x in k:
            if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
                y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
                state_dict[y] = state_dict.pop(x)

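        # some checkpoints store position_ids as float32; round them so they stay valid integer indices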
        if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in state_dict:
            ids = state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids']
            if ids.dtype == torch.float32:
                state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

        replace_prefix = {}
        replace_prefix["cond_stage_model."] = "clip_l."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
        for p in pop_keys:
            if p in state_dict:
                state_dict.pop(p)

        replace_prefix = {"clip_l.": "cond_stage_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def clip_target(self):
        return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)

class SD20(supported_models_base.BASE):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
        "attn_precision": torch.float32,
    }

    latent_format = latent_formats.SD15

    def model_type(self, state_dict, prefix=""):
        if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
            out = state_dict.get(k, None)
            if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                return model_base.ModelType.V_PREDICTION
        return model_base.ModelType.EPS

    def process_clip_state_dict(self, state_dict):
        replace_prefix = {}
        replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
        replace_prefix["cond_stage_model.model."] = "clip_h."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        replace_prefix["clip_h"] = "cond_stage_model.model"
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

class SD21UnclipL(SD20):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 1536,
        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 768}


class SD21UnclipH(SD20):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 2048,
        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1024}

class SDXLRefiner(supported_models_base.BASE):
    unet_config = {
        "model_channels": 384,
        "use_linear_in_transformer": True,
        "context_dim": 1280,
        "adm_in_channels": 2560,
        "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SDXLRefiner(self, device=device)

    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
        replace_prefix = {}
        replace_prefix["conditioner.embedders.0.model."] = "clip_g."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)

        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
        state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
            state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
        replace_prefix["clip_g"] = "conditioner.embedders.0.model"
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)

class SDXL(supported_models_base.BASE):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

    def model_type(self, state_dict, prefix=""):
        if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
            self.latent_format = latent_formats.SDXL_Playground_2_5()
            self.sampling_settings["sigma_data"] = 0.5
            self.sampling_settings["sigma_max"] = 80.0
            self.sampling_settings["sigma_min"] = 0.002
            return model_base.ModelType.EDM
        elif "edm_vpred.sigma_max" in state_dict:
            self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item())
            if "edm_vpred.sigma_min" in state_dict:
                self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
            return model_base.ModelType.V_PREDICTION_EDM
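        # a "v_pred" key in the state dict marks the checkpoint as v-prediction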
        elif "v_pred" in state_dict:
            return model_base.ModelType.V_PREDICTION
        else:
            return model_base.ModelType.EPS

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
        replace_prefix = {}

        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "clip_l.transformer.text_model"
        replace_prefix["conditioner.embedders.1.model."] = "clip_g."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)

        state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
        state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        keys_to_replace = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        for k in state_dict:
            if k.startswith("clip_l"):
                state_dict_g[k] = state_dict[k]

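        # recreate the position_ids buffer for the saved checkpoint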
        state_dict_g["clip_l.transformer.text_model.embeddings.position_ids"] = torch.arange(77).expand((1, -1))
        pop_keys = ["clip_l.transformer.text_projection.weight", "clip_l.logit_scale"]
        for p in pop_keys:
            if p in state_dict_g:
                state_dict_g.pop(p)

        replace_prefix["clip_g"] = "conditioner.embedders.1.model"
        replace_prefix["clip_l"] = "conditioner.embedders.0"
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

class SSD1B(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 4, 4],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

class Segmind_Vega(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 1, 1, 2, 2],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

class KOALA_700M(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 2, 5],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

class KOALA_1B(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 2, 6],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

class SVD_img2vid(supported_models_base.BASE):
    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 768,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    unet_extra_config = {
        "num_heads": -1,
        "num_head_channels": 64,
        "attn_precision": torch.float32,
    }

    clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."

    latent_format = latent_formats.SD15

    sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SVD_img2vid(self, device=device)
        return out

    def clip_target(self):
        return None

class SV3D_u(SVD_img2vid):
    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 256,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    vae_key_prefix = ["conditioner.embedders.1.encoder."]

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SV3D_u(self, device=device)
        return out

class SV3D_p(SV3D_u):
    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 1280,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }


    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SV3D_p(self, device=device)
        return out

class Stable_Zero123(supported_models_base.BASE):
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    required_keys = {
        "cc_projection.weight": None,
        "cc_projection.bias": None,
    }

    clip_vision_prefix = "cond_stage_model.model.visual."

    latent_format = latent_formats.SD15

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
        return out

    def clip_target(self):
        return None

class SD_X4Upscaler(SD20):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 256,
        'in_channels': 7,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config = {
        "disable_self_attentions": [True, True, True, False],
        "num_classes": 1000,
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format = latent_formats.SD_X4

    sampling_settings = {
        "linear_start": 0.0001,
        "linear_end": 0.02,
    }

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SD_X4Upscaler(self, device=device)
        return out

class Stable_Cascade_C(supported_models_base.BASE):
    unet_config = {
        "stable_cascade_stage": 'c',
    }

    unet_extra_config = {}

    latent_format = latent_formats.SC_Prior
    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    sampling_settings = {
        "shift": 2.0,
    }

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoder."]
    clip_vision_prefix = "clip_l_vision."

    def process_unet_state_dict(self, state_dict):
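        # split fused attention in_proj weights/biases into separate to_q/to_k/to_v tensors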
        key_list = list(state_dict.keys())
        for y in ["weight", "bias"]:
            suffix = "in_proj_{}".format(y)
            keys = filter(lambda a: a.endswith(suffix), key_list)
            for k_from in keys:
                weights = state_dict.pop(k_from)
                prefix = k_from[:-(len(suffix) + 1)]
                shape_from = weights.shape[0] // 3
                for x in range(3):
                    p = ["to_q", "to_k", "to_v"]
                    k_to = "{}.{}.{}".format(prefix, p[x], y)
                    state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
        return state_dict

    def process_clip_state_dict(self, state_dict):
        state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
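        # transpose the stored text_projection matrix into the transformer's text_projection.weight layout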
        if "clip_g.text_projection" in state_dict:
            state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1)
        return state_dict

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.StableCascade_C(self, device=device)
        return out

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)

class Stable_Cascade_B(Stable_Cascade_C):
    unet_config = {
        "stable_cascade_stage": 'b',
    }

    unet_extra_config = {}

    latent_format = latent_formats.SC_B
    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    sampling_settings = {
        "shift": 1.0,
    }

    clip_vision_prefix = None

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.StableCascade_B(self, device=device)
        return out

class SD15_instructpix2pix(SD15):
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SD15_instructpix2pix(self, device=device)

class SDXL_instructpix2pix(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)

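# checked in order by model detection, so more specific configs come before broader ones they would also match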
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]

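# SVD_img2vid is kept at the end of the list so it is only checked after the image model configs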
models += [SVD_img2vid]