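"""Gradio web UI for the Kolors text-to-image diffusion pipeline.

Expects the Kolors checkpoints under <repo root>/weights/Kolors with the
component layout the loaders below assume (text_encoder/, vae/, unet/,
scheduler/), e.g. as downloaded from the Kwai-Kolors/Kolors repository on
Hugging Face.
"""
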
import os
import torch
import gradio as gr
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from diffusers import UNet2DConditionModel, AutoencoderKL
from diffusers import EulerDiscreteScheduler

root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Initialize global variables for models and pipeline
text_encoder = None
tokenizer = None
vae = None
scheduler = None
unet = None
pipe = None


def remove_folder(path):
    if os.path.exists(path):
        if os.path.isfile(path) or os.path.islink(path):
            os.remove(path)
        else:
            for filename in os.listdir(path):
                remove_folder(os.path.join(path, filename))
            os.rmdir(path)
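
# Note: for a directory, shutil.rmtree(path, ignore_errors=True) from the
# standard library does the same job as remove_folder; the helper above also
# accepts a plain file or symlink path.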


def load_models():
    global text_encoder, tokenizer, vae, scheduler, unet, pipe

    if text_encoder is None:
        ckpt_dir = f'{root_dir}/weights/Kolors'

        # Load the ChatGLM text encoder in fp16 on the CPU; with model CPU
        # offload enabled below it is moved to the GPU only while encoding
        # (reportedly about a 2x speedup). torch_dtype=torch.float16 already
        # loads the weights in half precision, so no extra .half() is needed.
        text_encoder = ChatGLMModel.from_pretrained(
            f'{ckpt_dir}/text_encoder',
            torch_dtype=torch.float16).to('cpu')
        tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')

        # Load the VAE, scheduler and UNet in fp16; enable_model_cpu_offload
        # below manages their device placement, so there is no need to move
        # them to the GPU by hand.
        vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
        scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
        unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()

        # Assemble the SDXL pipeline around the ChatGLM text encoder.
        pipe = StableDiffusionXLPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            force_zeros_for_empty_prompt=False)
        # Offloading keeps each component on the CPU and moves it to the GPU
        # only while it runs; calling pipe.to("cuda") first would just be
        # undone when offloading is enabled, so it is omitted.
        pipe.enable_model_cpu_offload()


def infer(prompt, use_random_seed, seed, height, width, num_inference_steps, guidance_scale, num_images_per_prompt):
    load_models()

    if use_random_seed:
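        # torch.randint's upper bound is exclusive, so this samples a 32-bit
        # seed from [0, 2**32 - 2], within the UI slider's range.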
        seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()

    # Seed a torch.Generator so a given seed reproduces the same images.
    generator = torch.Generator(pipe.device).manual_seed(seed)
    images = pipe(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator
    ).images

    saved_images = []
    output_dir = f'{root_dir}/scripts/outputs'
    # Wipe results from the previous run before saving the new batch.
    remove_folder(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    for image in images:
        file_path = os.path.join(output_dir, 'sample_test.jpg')
        base_name, ext = os.path.splitext(file_path)
        counter = 1
        # The directory was just emptied, so the counter only needs to
        # de-duplicate filenames within this batch.
        while os.path.exists(file_path):
            file_path = f"{base_name}_{counter}{ext}"
            counter += 1
        image.save(file_path)
        saved_images.append(file_path)

    return saved_images
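

# Quick smoke test without launching the UI (the prompt and settings below
# are illustrative, not part of the original script):
#
#     load_models()
#     paths = infer("a photo of a red fox in the snow", use_random_seed=False,
#                   seed=42, height=1024, width=1024, num_inference_steps=20,
#                   guidance_scale=5.0, num_images_per_prompt=1)
#     print(paths)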


def gradio_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Kolors: Diffusion Model Gradio Interface")
                prompt = gr.Textbox(label="Prompt")
                use_random_seed = gr.Checkbox(label="Use Random Seed", value=True)
                seed = gr.Slider(minimum=0, maximum=2 ** 32 - 1, step=1, label="Seed", randomize=True, visible=False)
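                # Show the manual seed slider only while random seeding is off.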
                use_random_seed.change(lambda x: gr.update(visible=not x), use_random_seed, seed)
                height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=1024)
                width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=1024)
                num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, label="Inference Steps", value=50)
                guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0)
                num_images_per_prompt = gr.Slider(minimum=1, maximum=10, step=1, label="Images per Prompt", value=1)
                btn = gr.Button("Generate Image")

            with gr.Column():
                output_images = gr.Gallery(label="Output Images", elem_id="output_gallery")

        btn.click(
            fn=infer,
            inputs=[prompt, use_random_seed, seed, height, width, num_inference_steps, guidance_scale,
                    num_images_per_prompt],
            outputs=output_images
        )

    return demo


if __name__ == '__main__':
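    # Warm the models up front so the first UI request does not pay the
    # full loading cost.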
    load_models()
    gradio_interface().launch(server_name='0.0.0.0')
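
# root_dir above resolves to the parent of this file's directory, so the
# script is expected to live one level below the repo root (e.g. scripts/).
# Launch it with `python sampleui.py`, then open http://<host>:7860 in a
# browser: 7860 is Gradio's default port, and server_name='0.0.0.0' makes
# the UI reachable from other machines on the network.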