Commit e7dc03db authored by chenpangpang's avatar chenpangpang
Browse files

feat: 修改app.py避免重新下载模型

parent 06927f3c
......@@ -34,25 +34,22 @@ DEFAULT_STYLE_NAME = "Photographic (Default)"
ASPECT_RATIO_LABELS = list(aspect_ratios)
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
# download PhotoMaker checkpoint to cache
photomaker_ckpt = hf_hub_download(repo_id="TencentARC/PhotoMaker", filename="photomaker-v1.bin", repo_type="model")
torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
if device == "mps":
torch_dtype = torch.float16
pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
base_model_path,
base_model_path,
torch_dtype=torch_dtype,
use_safetensors=True,
use_safetensors=True,
variant="fp16",
# local_files_only=True,
).to(device)
pipe.load_photomaker_adapter(
os.path.dirname(photomaker_ckpt),
"TencentARC/PhotoMaker",
subfolder="",
weight_name=os.path.basename(photomaker_ckpt),
weight_name="photomaker-v1.bin",
trigger_word="img",
pm_version="v1",
)
......@@ -63,8 +60,10 @@ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.fuse_lora()
pipe.to(device)
@spaces.GPU
def generate_image(upload_images, prompt, negative_prompt, aspect_ratio_name, style_name, num_steps, style_strength_ratio, num_outputs, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
def generate_image(upload_images, prompt, negative_prompt, aspect_ratio_name, style_name, num_steps,
style_strength_ratio, num_outputs, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
# check the trigger word
image_token_id = pipe.tokenizer.convert_tokens_to_ids(pipe.trigger_word)
input_ids = pipe.tokenizer.encode(prompt)
......@@ -87,7 +86,7 @@ def generate_image(upload_images, prompt, negative_prompt, aspect_ratio_name, st
input_id_images = []
for img in upload_images:
input_id_images.append(load_image(img))
generator = torch.Generator(device=device).manual_seed(seed)
print("Start inference...")
......@@ -110,32 +109,40 @@ def generate_image(upload_images, prompt, negative_prompt, aspect_ratio_name, st
).images
return images, gr.update(visible=True)
def swap_to_gallery(images):
return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
def upload_example_to_gallery(images, prompt, style, negative_prompt):
return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
def remove_back_to_files():
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
def remove_tips():
return gr.update(visible=False)
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
if randomize_seed:
seed = random.randint(0, MAX_SEED)
return seed
def apply_style(style_name: str, positive: str, negative: str = ""):
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
return p.replace("{prompt}", positive), n + ' ' + negative
def get_image_path_list(folder_name):
image_basename_list = os.listdir(folder_name)
image_path_list = sorted([os.path.join(folder_name, basename) for basename in image_basename_list])
return image_path_list
def get_example():
case = [
[
......@@ -153,6 +160,7 @@ def get_example():
]
return case
### Description and style
logo = r"""
<center><img src='https://photo-maker.github.io/assets/logo.png' alt='PhotoMaker logo' style="width:80px; margin-bottom:10px"></center>
......@@ -206,7 +214,7 @@ tips = r"""
3. For **faster** speed, reduce the number of generated images and sampling steps. However, please note that reducing the sampling steps may compromise the ID fidelity.
"""
# We have provided some generate examples and comparisons at: [this website]().
# 3. Don't make the prompt too long, as we will trim it if it exceeds 77 tokens.
# 3. Don't make the prompt too long, as we will trim it if it exceeds 77 tokens.
# 4. When generating realistic photos, if it's not real enough, try switching to our other gradio application [PhotoMaker-Realistic]().
css = '''
......@@ -224,26 +232,27 @@ with gr.Blocks(css=css) as demo:
with gr.Row():
with gr.Column():
files = gr.Files(
label="Drag (Select) 1 or more photos of your face",
file_types=["image"]
)
label="Drag (Select) 1 or more photos of your face",
file_types=["image"]
)
uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
with gr.Column(visible=False) as clear_button:
remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
prompt = gr.Textbox(label="Prompt",
info="Try something like 'a photo of a man/woman img', 'img' is the trigger word.",
placeholder="A photo of a [man/woman img]...")
info="Try something like 'a photo of a man/woman img', 'img' is the trigger word.",
placeholder="A photo of a [man/woman img]...")
style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS,
value=DEFAULT_ASPECT_RATIO)
submit = gr.Button("Submit")
with gr.Accordion(open=False, label="Advanced Options"):
negative_prompt = gr.Textbox(
label="Negative Prompt",
label="Negative Prompt",
placeholder="low quality",
value="nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
)
num_steps = gr.Slider(
num_steps = gr.Slider(
label="Number of sample steps",
minimum=20,
maximum=100,
......@@ -281,14 +290,14 @@ with gr.Blocks(css=css) as demo:
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Column():
gallery = gr.Gallery(label="Generated Images")
usage_tips = gr.Markdown(label="Usage tips of PhotoMaker", value=tips ,visible=False)
usage_tips = gr.Markdown(label="Usage tips of PhotoMaker", value=tips, visible=False)
files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
submit.click(
fn=remove_tips,
outputs=usage_tips,
outputs=usage_tips,
).then(
fn=randomize_seed_fn,
inputs=[seed, randomize_seed],
......@@ -297,7 +306,8 @@ with gr.Blocks(css=css) as demo:
api_name=False,
).then(
fn=generate_image,
inputs=[files, prompt, negative_prompt, aspect_ratio, style, num_steps, style_strength_ratio, num_outputs, guidance_scale, seed],
inputs=[files, prompt, negative_prompt, aspect_ratio, style, num_steps, style_strength_ratio, num_outputs,
guidance_scale, seed],
outputs=[gallery, usage_tips]
)
......@@ -308,7 +318,7 @@ with gr.Blocks(css=css) as demo:
fn=upload_example_to_gallery,
outputs=[uploaded_files, clear_button, files],
)
gr.Markdown(article)
demo.launch(server_name='0.0.0.0', share=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment