# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
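# Example invocation (paths are placeholders):
#   python convert_stable_cascade.py --model_path ./StableCascade --use_safetensors --save_combined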
import argparse
from contextlib import nullcontext

import torch
from safetensors.torch import load_file
from transformers import (
    AutoTokenizer,
    CLIPConfig,
    CLIPImageProcessor,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    DDPMWuerstchenScheduler,
    StableCascadeCombinedPipeline,
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available


if is_accelerate_available():
    from accelerate import init_empty_weights

parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
    "--prior_output_path", default="stable-cascade-prior", type=str, help="Output path for the prior pipeline"
)
parser.add_argument(
    "--decoder_output_path",
    type=str,
    default="stable-cascade-decoder",
    help="Output path for the decoder pipeline",
)
parser.add_argument(
    "--combined_output_path",
    type=str,
    default="stable-cascade-combined",
    help="Output path for the combined pipeline",
)
parser.add_argument("--save_combined", action="store_true", help="Additionally save a combined prior/decoder pipeline")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")

args = parser.parse_args()

if args.skip_stage_b and args.skip_stage_c:
    raise ValueError("At least one stage should be converted")
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
    raise ValueError("Cannot skip stages when creating a combined pipeline")

model_path = args.model_path

device = "cpu"
if args.variant == "bf16":
    dtype = torch.bfloat16
else:
    dtype = torch.float32

# set paths to model weights
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"

# CLIP text encoder and tokenizer
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
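# Propagate the checkpoint's top-level projection_dim into text_config so that
# CLIPTextModelWithProjection is built with a matching projection head.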
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
)
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

# CLIP image processor and image encoder
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()
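# With accelerate available, instantiate the UNets on the meta device so no memory is
# allocated for their weights until the converted state dicts are loaded in below.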
ctx = init_empty_weights if is_accelerate_available() else nullcontext

if not args.skip_stage_c:
    # Prior
    if args.use_safetensors:
        prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
    else:
        prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)

    prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)

    with ctx():
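        # Stage C (prior) UNet; the configuration below is assumed to match the
        # original stage_c checkpoint architecture.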
        prior_model = StableCascadeUNet(
            in_channels=16,
            out_channels=16,
            timestep_ratio_embedding_dim=64,
            patch_size=1,
            conditioning_dim=2048,
            block_out_channels=[2048, 2048],
            num_attention_heads=[32, 32],
            down_num_layers_per_block=[8, 24],
            up_num_layers_per_block=[24, 8],
            down_blocks_repeat_mappers=[1, 1],
            up_blocks_repeat_mappers=[1, 1],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_in_channels=1280,
            clip_text_pooled_in_channels=1280,
            clip_image_in_channels=768,
            clip_seq=4,
            kernel_size=3,
            dropout=[0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca", "crp"],
            switch_level=[False],
        )
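    # Materialize the converted weights: load directly into the meta-device model when
    # accelerate is available, otherwise fall back to a regular state-dict load.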
    if is_accelerate_available():
        load_model_dict_into_meta(prior_model, prior_state_dict)
    else:
        prior_model.load_state_dict(prior_state_dict)

    # Prior pipeline
    prior_pipeline = StableCascadePriorPipeline(
        prior=prior_model,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
    )
    prior_pipeline.to(dtype).save_pretrained(
        args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )

if not args.skip_stage_b:
    # Decoder
    if args.use_safetensors:
        decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
    else:
        decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)

    decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)
    with ctx():
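        # Stage B (decoder) UNet; the configuration below is assumed to match the
        # original stage_b checkpoint architecture.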
        decoder = StableCascadeUNet(
            in_channels=4,
            out_channels=4,
            timestep_ratio_embedding_dim=64,
            patch_size=2,
            conditioning_dim=1280,
            block_out_channels=[320, 640, 1280, 1280],
            down_num_layers_per_block=[2, 6, 28, 6],
            up_num_layers_per_block=[6, 28, 6, 2],
            down_blocks_repeat_mappers=[1, 1, 1, 1],
            up_blocks_repeat_mappers=[3, 3, 2, 2],
            num_attention_heads=[0, 0, 20, 20],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_pooled_in_channels=1280,
            clip_seq=4,
            effnet_in_channels=16,
            pixel_mapper_in_channels=3,
            kernel_size=3,
            dropout=[0, 0, 0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca"],
        )

    if is_accelerate_available():
        load_model_dict_into_meta(decoder, decoder_state_dict)
    else:
        decoder.load_state_dict(decoder_state_dict)

    # VQGAN from Wuerstchen-V2
    vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")

    # Decoder pipeline
    decoder_pipeline = StableCascadeDecoderPipeline(
        decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
    )
    decoder_pipeline.to(dtype).save_pretrained(
        args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )

if args.save_combined:
    # Stable Cascade combined pipeline
    stable_cascade_pipeline = StableCascadeCombinedPipeline(
        # Decoder
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        decoder=decoder,
        scheduler=scheduler,
        vqgan=vqmodel,
        # Prior
        prior_text_encoder=text_encoder,
        prior_tokenizer=tokenizer,
        prior_prior=prior_model,
        prior_scheduler=scheduler,
        prior_image_encoder=image_encoder,
        prior_feature_extractor=feature_extractor,
    )
    stable_cascade_pipeline.to(dtype).save_pretrained(
        args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )
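
# A minimal inference sketch using the converted pipelines (prompt and paths are
# placeholders; assumes both stages were converted and saved):
#
#   prior = StableCascadePriorPipeline.from_pretrained("stable-cascade-prior")
#   decoder = StableCascadeDecoderPipeline.from_pretrained("stable-cascade-decoder")
#   prior_output = prior(prompt="an astronaut riding a horse")
#   image = decoder(image_embeddings=prior_output.image_embeddings, prompt="an astronaut riding a horse").images[0]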