import torch
from . import model_base
from . import utils

from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip

from . import supported_models_base
from . import latent_formats

from . import diffusers_convert

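# Each class below describes one supported checkpoint family: `unet_config`
# lists the UNet hyperparameters that identify the family, and the class also
# bundles the matching latent format, CLIP state-dict remapping and text
# encoder target. The `models` list at the bottom of this file is presumably
# what the rest of the package scans when deciding which config a loaded
# checkpoint belongs to (a hedged sketch of that lookup is at the end of the
# file).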
class SD15(supported_models_base.BASE):
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format = latent_formats.SD15

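    # Remap the checkpoint's CLIP weights into the layout expected here: add
    # the "text_model." segment some SD1.x checkpoints omit, round float32
    # position_ids that are stored as non-integer values, and move everything
    # under the "clip_l." prefix.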
    def process_clip_state_dict(self, state_dict):
        k = list(state_dict.keys())
        for x in k:
            if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
                y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
                state_dict[y] = state_dict.pop(x)

        if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in state_dict:
            ids = state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids']
            if ids.dtype == torch.float32:
                state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

        replace_prefix = {}
        replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {"clip_l.": "cond_stage_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def clip_target(self):
        return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)

class SD20(supported_models_base.BASE):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SD15

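    # Heuristic: SD2.x checkpoints do not record whether they were trained
    # with epsilon or v-prediction, so this peeks at the statistics of one
    # norm bias; v-prediction (768-v) checkpoints appear to have a noticeably
    # larger std there than eps checkpoints.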
    def model_type(self, state_dict, prefix=""):
        if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
            out = state_dict[k]
            if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                return model_base.ModelType.V_PREDICTION
        return model_base.ModelType.EPS

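    # Accept both the original SD2 layout and the sgm layout
    # (conditioner.embedders.0.model.), then convert the open_clip weights
    # into the transformers-style clip_h layout.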
    def process_clip_state_dict(self, state_dict):
        replace_prefix = {}
        replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)

        state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        replace_prefix["clip_h"] = "cond_stage_model.model"
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

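# The SD2.1 unCLIP variants additionally condition on CLIP vision embeddings
# (hence `clip_vision_prefix` and `noise_aug_config`). The "L" and "H"
# classes differ in the image embedder dimension (768 vs 1024), which is
# presumably also why their `adm_in_channels` differ.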
class SD21UnclipL(SD20):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 1536,
        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 768}


class SD21UnclipH(SD20):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 2048,
        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1024}

class SDXLRefiner(supported_models_base.BASE):
    unet_config = {
        "model_channels": 384,
        "use_linear_in_transformer": True,
        "context_dim": 1280,
        "adm_in_channels": 2560,
        "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SDXLRefiner(self, device=device)

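    # The refiner ships only the OpenCLIP-G text encoder (embedders.0),
    # remapped here to the clip_g layout.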
    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
        replace_prefix = {}

        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
        keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
        keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"

        state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
            state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
        replace_prefix["clip_g"] = "conditioner.embedders.0.model"
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)

class SDXL(supported_models_base.BASE):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

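    # Some v-prediction SDXL fine-tunes mark themselves with a "v_pred" key in
    # the state dict; everything else is treated as an eps model.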
    def model_type(self, state_dict, prefix=""):
        if "v_pred" in state_dict:
            return model_base.ModelType.V_PREDICTION
        else:
            return model_base.ModelType.EPS

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

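    # SDXL uses two text encoders: embedders.0 (CLIP-L) becomes clip_l and
    # embedders.1 (OpenCLIP bigG) becomes clip_g, which also carries the
    # text_projection and logit_scale tensors.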
    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
        replace_prefix = {}

        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
        keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
        keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
        keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"

        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        keys_to_replace = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
            state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
        for k in state_dict:
            if k.startswith("clip_l"):
                state_dict_g[k] = state_dict[k]

        replace_prefix["clip_g"] = "conditioner.embedders.1.model"
        replace_prefix["clip_l"] = "conditioner.embedders.0"
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

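# SSD-1B is a pruned/distilled SDXL variant, so it reuses all of SDXL's
# handling and only needs a different `transformer_depth` signature to be
# told apart from a full SDXL checkpoint.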
class SSD1B(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 4, 4],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

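# Stable Video Diffusion (img2vid): a video model with temporal attention and
# temporal resblocks, conditioned on CLIP vision embeddings of the input
# image rather than on text, which is why clip_target() returns None.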
class SVD_img2vid(supported_models_base.BASE):
    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 768,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."

    latent_format = latent_formats.SD15

    sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SVD_img2vid(self, device=device)
        return out

    def clip_target(self):
        return None

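# Registry of supported configs; the rest of the package presumably scans
# this list to find the class whose unet_config matches a loaded checkpoint.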
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
models += [SVD_img2vid]
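
# Minimal sketch (illustration only, not part of the original module) of how
# this registry can be consulted: treat each class's unet_config as a set of
# required key/value pairs and return the first class whose requirements are
# all satisfied by the UNet hyperparameters detected in a checkpoint. The
# helper name is hypothetical.
def _find_model_config(detected_unet_config):
    for candidate in models:
        if all(detected_unet_config.get(k) == v for k, v in candidate.unet_config.items()):
            return candidate
    return None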