"ollama/llm/llama.cpp/ggml/src/ggml-cuda/tsembd.cuh" did not exist on "ff27a8172ae24bbcff76eec4220c3081852c201b"
import random
import torch
import cv2
import insightface
import gradio as gr
import numpy as np
import os
from huggingface_hub import snapshot_download
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter_FaceID import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from diffusers import AutoencoderKL
from kolors.models.unet_2d_condition import UNet2DConditionModel
from diffusers import EulerDiscreteScheduler
from PIL import Image
from insightface.app import FaceAnalysis
from insightface.data import get_image as ins_get_image

device = "cuda"
# resolve the HF repos to local snapshot directories so the f"{ckpt_dir}/..." paths below point at real files
ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
ckpt_dir_faceid = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus")

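# Load the Kolors base components; the text encoder, VAE and UNet run in fp16 on the GPU.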
text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
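
# CLIP image encoder and processor for the IP-Adapter FaceID image branch (336x336 face crops).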
clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ckpt_dir_faceid}/clip-vit-large-patch14-336',
                                                                   ignore_mismatched_sizes=True)
clip_image_encoder.to(device)
clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)

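# Assemble the Kolors SDXL pipeline, attaching the face CLIP encoder/processor for IP-Adapter FaceID.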
pipe = StableDiffusionXLPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    face_clip_encoder=clip_image_encoder,
    face_clip_processor=clip_image_processor,
    force_zeros_for_empty_prompt=False,
)


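# Thin wrapper around insightface's FaceAnalysis ("antelopev2") for face detection and ID embeddings.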
class FaceInfoGenerator():
    def __init__(self, root_dir="./.insightface/"):
        self.app = FaceAnalysis(name='antelopev2', root=root_dir,
                                providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.app.prepare(ctx_id=0, det_size=(640, 640))

    def get_faceinfo_one_img(self, face_image):
        face_info = self.app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))

        if len(face_info) == 0:
            face_info = None
        else:
            # keep only the largest detected face (by bounding-box area)
            face_info = sorted(
                face_info,
                key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1])
            )[-1]
        return face_info


def face_bbox_to_square(bbox):
    # expand an (l, t, r, b) bbox into a square with the same center
    l, t, r, b = bbox
    cent_x = (l + r) / 2
    cent_y = (t + b) / 2
    w, h = r - l, b - t
    r = max(w, h) / 2

    l0 = cent_x - r
    r0 = cent_x + r
    t0 = cent_y - r
    b0 = cent_y + r

    return [l0, t0, r0, b0]


MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
face_info_generator = FaceInfoGenerator()


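# Generation entry point wired to the Gradio UI below.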
def infer(prompt,
          image=None,
          # default negative prompt (Chinese): "nsfw, facial shadows, low resolution,
          # JPEG artifacts, blurry, poor quality, dark face, neon lights"
          negative_prompt="nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯",
          seed=66,
          randomize_seed=False,
          guidance_scale=5.0,
          num_inference_steps=50
          ):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    global pipe
    pipe = pipe.to(device)
    pipe.load_ip_adapter_faceid_plus(f'{ckpt_dir_faceid}/ipa-faceid-plus.bin', device=device)
    scale = 0.8
    pipe.set_face_fidelity_scale(scale)

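    # Detect the largest face in the reference image, crop a 336x336 square around it,
    # and take its insightface identity embedding.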
    face_info = face_info_generator.get_faceinfo_one_img(image)
    if face_info is None:
        # surface a friendly error in the UI instead of crashing on face_info["bbox"] below
        raise gr.Error("No face detected in the uploaded image; please use a photo with a clearly visible face.")
    face_bbox_square = face_bbox_to_square(face_info["bbox"])
    crop_image = image.crop(face_bbox_square)
    crop_image = crop_image.resize((336, 336))
    crop_image = [crop_image]
    face_embeds = torch.from_numpy(np.array([face_info["embedding"]]))
    face_embeds = face_embeds.to(device, dtype=torch.float16)

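    # Run the diffusion pipeline; the face crop and identity embedding drive the IP-Adapter FaceID conditioning.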
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=1024,
        width=1024,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
        face_crop_image=crop_image,
        face_insightface_embeds=face_embeds
    ).images[0]

    return image, seed


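# Example (prompt, reference image) pairs. Prompts are in Chinese; roughly:
# 1) "wearing an evening gown at a starlit banquet, candlelight flickering, a romantic and luxurious atmosphere"
# 2) "western cowboy, cowboy hat, Wild West gunslinger, western town backdrop, cacti, sunset glow, warm tones,
#    shot on XT4 film, grain, vignette, Kodak film, retro"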
examples = [
    ["穿着晚礼服,在星光下的晚宴场景中,烛光闪闪,整个场景洋溢着浪漫而奢华的氛围", "image/image1.png"],
    ["西部牛仔,牛仔帽,荒野大镖客,背景是西部小镇,仙人掌,日落余晖, 暖色调, 使用XT4胶片拍摄, 噪点, 晕影, 柯达胶卷,复古",
     "image/image2.png"]
]

css = """
#col-left {
    margin: 0 auto;
    max-width: 600px;
}
#col-right {
    margin: 0 auto;
    max-width: 750px;
}
#button {
    color: blue;
}
"""


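# Read an HTML/Markdown fragment from disk (used below for the page title).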
def load_description(fp):
    with open(fp, 'r', encoding='utf-8') as f:
        content = f.read()
    return content


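# Gradio UI: prompt and face image inputs on the left, generated image and seed on the right.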
with gr.Blocks(css=css) as Kolors:
    gr.HTML(load_description("assets/title.md"))
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            with gr.Row():
                prompt = gr.Textbox(
                    label="提示词",
                    placeholder="请输入提示词",
                    lines=2
                )
            with gr.Row():
                image = gr.Image(label="图像", type="pil")
            with gr.Accordion("高级设置", open=False):
                negative_prompt = gr.Textbox(
                    label="负面提示词",
                    placeholder="请输入负面提示词",
                    visible=True,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="随机seed", value=True)
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=0.0,
                        maximum=10.0,
                        step=0.1,
                        value=5.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=25,
                    )
            with gr.Row():
                button = gr.Button("运行", elem_id="button")

        with gr.Column(elem_id="col-right"):
            result = gr.Image(label="输出", show_label=False)
            seed_used = gr.Number(label="使用的Seed")

    with gr.Row():
        gr.Examples(
            fn=infer,
            examples=examples,
            inputs=[prompt, image],
            outputs=[result, seed_used],
        )

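    # Run generation when the button is clicked.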
    button.click(
        fn=infer,
        inputs=[prompt, image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps],
        outputs=[result, seed_used]
    )

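# Enable request queuing and serve on all interfaces with a public share link.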
Kolors.queue().launch(server_name="0.0.0.0", share=True)