# app.py
import torch
import numpy as np
import random
import os
import sys

from diffusers.utils import load_image
from diffusers import EulerDiscreteScheduler

from huggingface_hub import hf_hub_download
import gradio as gr

from photomaker import PhotoMakerStableDiffusionXLPipeline

from style_template import styles
from aspect_ratio_template import aspect_ratios

# --- Global configuration ---------------------------------------------------
base_model_path = 'SG161222/RealVisXL_V4.0'

# Pick the compute device: CUDA first, then Apple MPS (macOS only), else CPU.
# Any error raised while probing the backends falls back to CPU.
try:
    if torch.cuda.is_available():
        device = "cuda"
    elif sys.platform == "darwin" and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
except Exception:
    device = "cpu"

MAX_SEED = np.iinfo(np.int32).max  # largest seed offered by the UI / randomizer
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Photographic (Default)"
ASPECT_RATIO_LABELS = list(aspect_ratios)
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]

# Weight dtype per device. The bf16 capability query is only meaningful (and,
# on some torch versions, only safe to call) when CUDA is actually in use.
# On CPU, half-precision kernels are missing or slow for many ops, so use fp32.
if device == "cuda":
    torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
elif device == "mps":
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

# Load the SDXL base model with the dtype chosen above and move it to `device`.
# variant="fp16" selects the fp16 weight files; they are loaded as `torch_dtype`.
pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch_dtype,
    use_safetensors=True,
    variant="fp16",
    # local_files_only=True,
).to(device)

# Attach the PhotoMaker v1 adapter. `trigger_word` is the token that
# generate_image requires to appear exactly once in the prompt.
pipe.load_photomaker_adapter(
    "TencentARC/PhotoMaker",
    subfolder="",
    weight_name="photomaker-v1.bin",
    trigger_word="img",
    pm_version="v1",
)
pipe.id_encoder.to(device)

# Swap in an Euler scheduler built from the pipeline's existing scheduler config.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
# pipe.set_adapters(["photomaker"], adapter_weights=[1.0])
# Fuse the loaded LoRA weights into the base model, then make sure every
# component ends up on the target device.
pipe.fuse_lora()
pipe.to(device)

def generate_image(upload_images, prompt, negative_prompt, aspect_ratio_name, style_name, num_steps,
                   style_strength_ratio, num_outputs, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
    """Generate images of the uploaded person with the PhotoMaker pipeline.

    Args:
        upload_images: Items from the ``gr.Files`` input, each passed to
            ``load_image``; ``None`` or empty when nothing was uploaded.
        prompt: Text prompt; must contain ``pipe.trigger_word`` exactly once.
        negative_prompt: User negative prompt, combined with the style's own
            negative text by ``apply_style``.
        aspect_ratio_name: Key into ``aspect_ratios`` giving ``(width, height)``.
        style_name: Style template name passed to ``apply_style``.
        num_steps: Number of inference steps.
        style_strength_ratio: Percentage of ``num_steps`` that is converted into
            the pipeline's ``start_merge_step`` (capped at 30).
        num_outputs: Number of images to generate for the prompt.
        guidance_scale: Guidance scale forwarded to the pipeline.
        seed: Seed for the ``torch.Generator`` created on ``device``.
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        ``(images, gr.update(visible=True))``: the generated images and an
        update that makes the tips panel visible.

    Raises:
        gr.Error: if the trigger word is missing or repeated in the prompt, or
            if no face image was uploaded.
    """
    # Slider values may arrive as floats (e.g. 2.0); the generator seed and the
    # step/image counts must be ints.
    num_steps = int(num_steps)
    num_outputs = int(num_outputs)
    seed = int(seed)

    # check the trigger word: it must occur exactly once in the tokenized prompt
    image_token_id = pipe.tokenizer.convert_tokens_to_ids(pipe.trigger_word)
    input_ids = pipe.tokenizer.encode(prompt)
    trigger_count = input_ids.count(image_token_id)
    if trigger_count == 0:
        raise gr.Error(f"Cannot find the trigger word '{pipe.trigger_word}' in text prompt! Please refer to step 2️⃣")
    if trigger_count > 1:
        raise gr.Error(f"Cannot use multiple trigger words '{pipe.trigger_word}' in text prompt!")

    # determine output dimensions by the aspect ratio
    output_w, output_h = aspect_ratios[aspect_ratio_name]
    print(f"Generate image using aspect ratio [{aspect_ratio_name}] => {output_w} x {output_h}")

    # apply the style template
    prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)

    # An empty upload list would leave the pipeline with no ID images, so it is
    # rejected just like None.
    if not upload_images:
        raise gr.Error("Cannot find any input face image! Please refer to step 1️⃣")

    input_id_images = [load_image(img) for img in upload_images]

    generator = torch.Generator(device=device).manual_seed(seed)

    print("Start inference...")
    print(f"Prompt: {prompt}, \n Neg Prompt: {negative_prompt}")
    # The style-strength percentage of the total steps, capped at step 30.
    start_merge_step = min(int(float(style_strength_ratio) / 100 * num_steps), 30)
    print(f"Start merge step: {start_merge_step}")
    images = pipe(
        prompt=prompt,
        width=output_w,
        height=output_h,
        input_id_images=input_id_images,
        negative_prompt=negative_prompt,
        num_images_per_prompt=num_outputs,
        num_inference_steps=num_steps,
        start_merge_step=start_merge_step,
        generator=generator,
        guidance_scale=guidance_scale,
    ).images
    return images, gr.update(visible=True)

def swap_to_gallery(images):
    """Show freshly uploaded images in the preview gallery.

    Returns three updates, in order: the gallery (filled with ``images`` and
    made visible), the clear-button column (made visible) and the file picker
    (hidden).
    """
    gallery_update = gr.update(value=images, visible=True)
    clear_button_update = gr.update(visible=True)
    file_picker_update = gr.update(visible=False)
    return gallery_update, clear_button_update, file_picker_update

def upload_example_to_gallery(images, prompt, style, negative_prompt):
    """Load a clicked example's images into the preview gallery.

    ``gr.Examples`` passes every input column, so ``prompt``, ``style`` and
    ``negative_prompt`` are accepted but not used here. Returns updates for the
    gallery (shown with ``images``), the clear-button column (shown) and the
    file picker (hidden).
    """
    return (
        gr.update(value=images, visible=True),  # preview gallery
        gr.update(visible=True),                # clear-button column
        gr.update(visible=False),               # file picker
    )

def remove_back_to_files():
    """Undo an upload view: hide the gallery and clear button, show the file picker."""
    return (
        gr.update(visible=False),  # preview gallery
        gr.update(visible=False),  # clear-button column
        gr.update(visible=True),   # file picker
    )


def remove_tips():
    """Return an update that hides the usage-tips panel."""
    hidden = gr.update(visible=False)
    return hidden

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return ``seed`` unchanged, or a fresh value in ``[0, MAX_SEED]`` when ``randomize_seed`` is truthy."""
    if not randomize_seed:
        return seed
    return random.randint(0, MAX_SEED)

def apply_style(style_name: str, positive: str, negative: str = ""):
    """Combine a style template with the user's prompts.

    The ``(template, style_negative)`` pair is looked up in ``styles`` by
    ``style_name``, falling back to ``DEFAULT_STYLE_NAME``. ``{prompt}`` in the
    template is replaced with ``positive``; the style's negative text is joined
    to ``negative`` with a single space.

    Returns:
        ``(styled_positive, combined_negative)``.
    """
    fallback = styles[DEFAULT_STYLE_NAME]  # evaluated up front, as before
    template, style_negative = styles.get(style_name, fallback)
    styled_positive = template.replace("{prompt}", positive)
    combined_negative = ' '.join((style_negative, negative))
    return styled_positive, combined_negative

def get_image_path_list(folder_name):
    """Return the sorted paths of the regular, non-hidden files in ``folder_name``.

    ``os.listdir`` also reports hidden entries (e.g. ``.DS_Store``) and
    sub-directories. Those cannot be opened as images by ``load_image``, so
    they are skipped.
    """
    candidates = (
        os.path.join(folder_name, basename)
        for basename in os.listdir(folder_name)
        if not basename.startswith('.')
    )
    return sorted(path for path in candidates if os.path.isfile(path))

def get_example():
    """Build the rows for ``gr.Examples``.

    Each row is ``[image paths, prompt, style name, negative prompt]``, matching
    the ``inputs=[files, prompt, style, negative_prompt]`` order of the examples
    widget.
    """
    no_style = "(No style)"
    shared_negative = "(asymmetry, worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), open mouth"
    samples = (
        ('./examples/scarletthead_woman',
         "instagram photo, portrait photo of a woman img, colorful, perfect face, natural skin, hard shadows, film grain"),
        ('./examples/newton_man',
         "sci-fi, closeup portrait photo of a man img wearing the sunglasses in Iron man suit, face, slim body, high quality, film grain"),
    )
    return [
        [get_image_path_list(folder), text, no_style, shared_negative]
        for folder, text in samples
    ]

### Description and style
# Markdown/HTML snippets rendered at the top of the demo page.
logo = r"""
<center><img src='https://photo-maker.github.io/assets/logo.png' alt='PhotoMaker logo' style="width:80px; margin-bottom:10px"></center>
"""
title = r"""
<h1 align="center">PhotoMaker:通过少量样本提取的个性化信息,在自然语言描述引导下,生成逼真的照片或者艺术化的图片</h1>
"""

# Step-by-step instructions. The prompt-validation errors in generate_image
# point users to steps 1 and 2.
description = r"""
个性化步骤:<br>
1️⃣ 上传您想要自定义的某个人的图像。一张或多张图片都行,建议多张。此工具不进行人脸检测,上传图像中的人脸应该占据图像的大部分。<br>
2️⃣ 输入文本提示符,确保按照您想要自定义的类单词使用触发词: `img`, 例如: `man img` 、 `woman img` 或 `girl img`。<br>
3️⃣ 选择您喜欢的风格模板。<br>
4️⃣ 单击提交按钮开始自定义。
"""

# Usage tips, shown in the results column once generation has run.
tips = r"""
### 使用技巧
1. 上传多张要定制的人的照片,以**提高身份识别精度**。如果输入是亚洲面孔,也许可以考虑在类单词之前添加“asian”,例如`asian woman img`
2. 为了**更快**的速度,减少生成的图像数量和采样步骤。但是,请注意,减少采样步骤可能会降低ID保真度。
"""
# Additional tip lines that are currently not shown in the UI:
# We have provided some generated examples and comparisons at: [this website]().
# 3. Don't make the prompt too long, as we will trim it if it exceeds 77 tokens.
# 4. When generating realistic photos, if it's not real enough, try switching to our other gradio application [PhotoMaker-Realistic]().

# Widen the Gradio container to 85% of the page.
css = '''
.gradio-container {width: 85% !important}
'''
with gr.Blocks(css=css) as demo:
    gr.Markdown(logo)
    gr.Markdown(title)
    gr.Markdown(description)
    # gr.DuplicateButton(
    #     value="Duplicate Space for private use ",
    #     elem_id="duplicate-button",
    #     visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    # )
    with gr.Row():
        # Left column: face-image upload, prompt and generation settings.
        with gr.Column():
            files = gr.Files(
                label="上传/选择一张或多张人脸照片",
                file_types=["image"]
            )
            # Preview gallery and clear button start hidden; swap_to_gallery
            # reveals them after an upload.
            uploaded_files = gr.Gallery(label="你的图片", visible=False, columns=5, rows=1, height=200)
            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="移除并重新上传", components=files, size="sm")
            prompt = gr.Textbox(label="Prompt",
                                info="尝试类似'a photo of a man/woman img'的词, 'img'是触发词,必须包含",
                                placeholder="A photo of a [man/woman img]...")
            style = gr.Dropdown(label="风格", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
            aspect_ratio = gr.Dropdown(label="输出纵横比", choices=ASPECT_RATIO_LABELS,
                                       value=DEFAULT_ASPECT_RATIO)
            submit = gr.Button("提交")

            with gr.Accordion(open=False, label="高级选项"):
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="low quality",
                    value="nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
                )
                num_steps = gr.Slider(
                    label="Number of sample steps",
                    minimum=20,
                    maximum=100,
                    step=1,
                    value=50,
                )
                style_strength_ratio = gr.Slider(
                    label="Style strength (%)",
                    minimum=15,
                    maximum=50,
                    step=1,
                    value=20,
                )
                num_outputs = gr.Slider(
                    label="Number of output images",
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=2,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.1,
                    maximum=10.0,
                    step=0.1,
                    value=5,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        # Right column: generated images and (initially hidden) usage tips.
        with gr.Column():
            gallery = gr.Gallery(label="生成图片")
            usage_tips = gr.Markdown(label="使用技巧", value=tips, visible=False)

        # Toggle between the file picker and the preview gallery.
        files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
        remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])

        # Submit chain: hide the tips, optionally re-roll the seed, then generate.
        submit.click(
            fn=remove_tips,
            outputs=usage_tips,
        ).then(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=generate_image,
            inputs=[files, prompt, negative_prompt, aspect_ratio, style, num_steps, style_strength_ratio, num_outputs,
                    guidance_scale, seed],
            outputs=[gallery, usage_tips]
        )

    # Clicking an example fills the inputs and shows its images in the preview gallery.
    gr.Examples(
        examples=get_example(),
        inputs=[files, prompt, style, negative_prompt],
        run_on_click=True,
        fn=upload_example_to_gallery,
        outputs=[uploaded_files, clear_button, files],
    )
# Start the web server only when run as a script, so that importing this module
# (e.g. to reuse `demo` or `generate_image`) does not launch it.
# NOTE(review): server_name='0.0.0.0' listens on all network interfaces and
# share=True requests a public Gradio share link; confirm this exposure is intended.
if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0', share=True)