supported_models.py
import torch
from . import model_base
from . import utils

from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip

from . import supported_models_base
from . import latent_formats
from . import diffusers_convert

class SD15(supported_models_base.BASE):
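    # Stable Diffusion 1.x family: 768-dim CLIP ViT-L text context, no ADM
    # conditioning and convolutional (not linear) projections in the spatial
    # transformers. These keys are compared against the UNet config detected
    # from a checkpoint to identify the model.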
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
    }

    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format = latent_formats.SD15

    def process_clip_state_dict(self, state_dict):
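        # Older SD1.x checkpoints omit the "text_model." segment in their
        # text-encoder keys; insert it so the keys match the transformers-style
        # layout used internally.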
        k = list(state_dict.keys())
        for x in k:
            if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
                y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
                state_dict[y] = state_dict.pop(x)

        # Some checkpoints store position_ids as float32 with slightly off
        # values; round them so they map back to exact integer indices.
        if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in state_dict:
            ids = state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids']
            if ids.dtype == torch.float32:
                state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

        replace_prefix = {}
        replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
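        # Inverse of process_clip_state_dict: put the text-encoder weights back
        # under the "cond_stage_model." prefix used by standard checkpoints.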
        replace_prefix = {"clip_l.": "cond_stage_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def clip_target(self):
        return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)

class SD20(supported_models_base.BASE):
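    # Stable Diffusion 2.x: 1024-dim OpenCLIP ViT-H text context and linear
    # projections in the spatial transformers. Uses the same latent format as
    # SD1.5.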
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
    }

    latent_format = latent_formats.SD15

    def model_type(self, state_dict, prefix=""):
        if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
            out = state_dict[k]
            if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                return model_base.ModelType.V_PREDICTION
        return model_base.ModelType.EPS

    def process_clip_state_dict(self, state_dict):
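        # SD2.x checkpoints ship an OpenCLIP-style text encoder; convert its
        # 24 transformer layers into the transformers-style clip_h layout used
        # internally.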
        state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
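        # Reverse of the above for saving: rename clip_h back to
        # cond_stage_model.model and convert the weights back to the OpenCLIP
        # layout.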
        replace_prefix = {}
        replace_prefix["clip_h"] = "cond_stage_model.model"
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

class SD21UnclipL(SD20):
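    # SD 2.1 unCLIP variant conditioned on CLIP ViT-L image embeddings: the
    # extra ADM input channels and the 768-dim noise augmentor distinguish it
    # from plain SD 2.x; clip_vision_prefix points at the CLIP vision weights
    # bundled in the checkpoint.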
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 1536,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 768}


class SD21UnclipH(SD20):
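    # Same as SD21UnclipL but conditioned on 1024-dim CLIP ViT-H image
    # embeddings.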
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 2048,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1024}

class SDXLRefiner(supported_models_base.BASE):
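    # SDXL refiner: 384 base channels, a single 1280-dim text encoder (clip_g)
    # and transformer blocks only in the middle resolution levels, per the
    # transformer_depth pattern below.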
    unet_config = {
        "model_channels": 384,
        "use_linear_in_transformer": True,
        "context_dim": 1280,
        "adm_in_channels": 2560,
        "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
    }

    latent_format = latent_formats.SDXL

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SDXLRefiner(self, device=device)

    def process_clip_state_dict(self, state_dict):
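        # The refiner's only text encoder is OpenCLIP-G (embedders.0): convert
        # its 32 transformer layers to the internal clip_g layout and remap
        # text_projection and logit_scale.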
        keys_to_replace = {}
        replace_prefix = {}

        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
        keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
        keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"

        state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
            state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
        replace_prefix["clip_g"] = "conditioner.embedders.0.model"
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)

class SDXL(supported_models_base.BASE):
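    # SDXL base: dual text encoders (CLIP-L + OpenCLIP-G, 2048-dim combined
    # context) plus ADM conditioning carrying the pooled text embedding and
    # the size/crop parameters (adm_in_channels=2816).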
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816
    }

    latent_format = latent_formats.SDXL

    def model_type(self, state_dict, prefix=""):
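        # Some SDXL fine-tunes mark v-prediction by including a "v_pred" key
        # in the checkpoint; default to epsilon prediction otherwise.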
        if "v_pred" in state_dict:
            return model_base.ModelType.V_PREDICTION
        else:
            return model_base.ModelType.EPS

    def get_model(self, state_dict, prefix="", device=None):
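        # inpaint_model() comes from the BASE class and reports whether the
        # detected UNet expects the extra masked-image input channels.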
        out = model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

    def process_clip_state_dict(self, state_dict):
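        # Two text encoders: embedders.0 is a standard CLIP-L model (plain
        # prefix rename to clip_l); embedders.1 is OpenCLIP-G and needs the
        # 32-layer conversion plus explicit remapping of text_projection and
        # logit_scale.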
        keys_to_replace = {}
        replace_prefix = {}

        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
        keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
        keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
        keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"

        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
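        # Convert clip_g back to the OpenCLIP layout, drop its position_ids,
        # copy the clip_l weights over unchanged, then restore the
        # conditioner.embedders.* prefixes used by reference checkpoints.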
        replace_prefix = {}
        keys_to_replace = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
            state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
        for k in state_dict:
            if k.startswith("clip_l"):
                state_dict_g[k] = state_dict[k]

        replace_prefix["clip_g"] = "conditioner.embedders.1.model"
        replace_prefix["clip_l"] = "conditioner.embedders.0"
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

class SSD1B(SDXL):
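    # Segmind SSD-1B, a distilled SDXL: same text conditioning as SDXL but a
    # shallower transformer_depth pattern, which is what tells the two apart
    # during detection.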
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 4, 4],
        "context_dim": 2048,
        "adm_in_channels": 2816
    }


models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
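# Rough usage sketch (hypothetical, for illustration only -- the actual lookup
# lives in this package's model detection code): the loader derives a UNet
# config from the checkpoint and picks the first entry here whose unet_config
# keys all match, along the lines of:
#
#   def find_model_config(detected_unet_config):
#       for candidate in models:
#           if all(detected_unet_config.get(k) == v for k, v in candidate.unet_config.items()):
#               return candidate
#       return None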