run_sdxl_with_custom_components.py 4.85 KB
Newer Older
wangwf's avatar
init  
wangwf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import copy
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from migraphx_diffusers import (MIGraphXAutoencoderKL, MIGraphXCLIPTextModel,
                                MIGraphXCLIPTextModelWithProjection, 
                                MIGraphXUNet2DConditionModel)
from transformers import AutoTokenizer
import numpy as np
import torch


def parse_args():
    """Parse command-line arguments for SDXL inference with the migraphx backend.

    Returns:
        argparse.Namespace: parsed options covering model loading/compilation
        (model dir, force-compile flag, image size, images per prompt) and
        generation (prompt, negative prompt, steps, save prefix, seed).
    """
    from argparse import ArgumentParser
    parser = ArgumentParser(description="SDXL inference with migraphx backend")

    #=========================== model load and compile ========================
    parser.add_argument(
        "-m", 
        "--model-dir",
        type=str,
        required=True,
        help="Path to local model directory.",
    )
    parser.add_argument(
        "--force-compile",
        action="store_true",
        default=False,
        help="Ignore existing .mxr files and override them",
    )
    parser.add_argument(
        "--img-size",
        type=int,
        default=1024,
        help="output image size",
    )
    parser.add_argument(
        "--num-images-per-prompt",
        type=int,
        default=1,
        help="The number of images to generate per prompt."
    )
    # --------------------------------------------------------------------------

    # =============================== generation ===============================
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        required=True,
        help="Prompt describing image content, style and so on."
    )
    parser.add_argument(
        "-n",
        "--negative-prompt",
        type=str,
        default=None,
        help="Negative prompt",
    )
    parser.add_argument(
        "-t",
        "--num-inference-steps",
        type=int,
        default=50,
        help="Number of iteration steps",
    )
    parser.add_argument(
        "--save-prefix",
        type=str,
        default="sdxl_output",
        help="Prefix of path for saving results",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Random seed",
    )
    # --------------------------------------------------------------------------

    args = parser.parse_args()
    return args


def main():
    """Assemble an SDXL pipeline from MIGraphX-compiled components and generate images."""
    opts = parse_args()

    model_root = opts.model_dir
    # Keyword arguments shared by every MIGraphX component loader.
    shared_kwargs = {
        'batch': opts.num_images_per_prompt,
        'img_size': opts.img_size,
        'model_dtype': 'fp16',
        'force_compile': opts.force_compile,
    }
    # Text encoders always run at batch 1, independent of the image batch.
    encoder_kwargs = copy.deepcopy(shared_kwargs)
    encoder_kwargs['batch'] = 1

    # ========================== load migraphx models ==========================
    text_encoder = MIGraphXCLIPTextModel.from_pretrained(
        model_root, subfolder="text_encoder", **encoder_kwargs)
    text_encoder_2 = MIGraphXCLIPTextModelWithProjection.from_pretrained(
        model_root, subfolder="text_encoder_2", **encoder_kwargs)
    unet = MIGraphXUNet2DConditionModel.from_pretrained(
        model_root, subfolder="unet", pipeline_class=StableDiffusionXLPipeline,
        **shared_kwargs)
    vae = MIGraphXAutoencoderKL.from_pretrained(
        model_root, subfolder="vae_decoder", **shared_kwargs)
    # --------------------------------------------------------------------------

    # ============================ load torch models ===========================
    scheduler = EulerDiscreteScheduler.from_pretrained(
        model_root, subfolder="scheduler")
    tokenizer = AutoTokenizer.from_pretrained(
        model_root, subfolder="tokenizer")
    tokenizer_2 = AutoTokenizer.from_pretrained(
        model_root, subfolder="tokenizer_2")
    # --------------------------------------------------------------------------

    # Wire the custom components into a stock diffusers pipeline.
    pipe = StableDiffusionXLPipeline(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
            force_zeros_for_empty_prompt=True,
            add_watermarker=None,
    )
    pipe.to("cuda")
    pipe.to(torch.float16)

    # Record which components are MIGraphX-backed plus the compiled shapes,
    # so downstream code can validate inputs against them.
    pipe.register_to_config(
        _mgx_models=["text_encoder", "text_encoder_2", "unet", "vae"])
    pipe.register_to_config(_batch=opts.num_images_per_prompt)
    pipe.register_to_config(_img_height=opts.img_size)
    pipe.register_to_config(_img_width=opts.img_size)

    # Run generation with a fixed-seed CUDA generator for reproducibility.
    print("Generating image...")
    images = pipe(
        prompt=opts.prompt,
        negative_prompt=opts.negative_prompt,
        num_inference_steps=opts.num_inference_steps,
        generator=torch.Generator("cuda").manual_seed(opts.seed)
    ).images

    for idx, img in enumerate(images):
        save_path = f"{opts.save_prefix}_{idx}.png"
        img.save(save_path)
        print(f"Generated image: {save_path}")


# Script entry point: run CLI-driven SDXL generation when executed directly.
if __name__ == "__main__":
    main()