Commit e7dc03db authored by chenpangpang's avatar chenpangpang
Browse files

feat: 修改app.py避免重新下载模型

parent 06927f3c
......@@ -34,9 +34,6 @@ DEFAULT_STYLE_NAME = "Photographic (Default)"
ASPECT_RATIO_LABELS = list(aspect_ratios)
DEFAULT_ASPECT_RATIO = ASPECT_RATIO_LABELS[0]
# download PhotoMaker checkpoint to cache
photomaker_ckpt = hf_hub_download(repo_id="TencentARC/PhotoMaker", filename="photomaker-v1.bin", repo_type="model")
torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
if device == "mps":
torch_dtype = torch.float16
......@@ -50,9 +47,9 @@ pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
).to(device)
pipe.load_photomaker_adapter(
os.path.dirname(photomaker_ckpt),
"TencentARC/PhotoMaker",
subfolder="",
weight_name=os.path.basename(photomaker_ckpt),
weight_name="photomaker-v1.bin",
trigger_word="img",
pm_version="v1",
)
......@@ -63,8 +60,10 @@ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.fuse_lora()
pipe.to(device)
@spaces.GPU
def generate_image(upload_images, prompt, negative_prompt, aspect_ratio_name, style_name, num_steps, style_strength_ratio, num_outputs, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
def generate_image(upload_images, prompt, negative_prompt, aspect_ratio_name, style_name, num_steps,
style_strength_ratio, num_outputs, guidance_scale, seed, progress=gr.Progress(track_tqdm=True)):
# check the trigger word
image_token_id = pipe.tokenizer.convert_tokens_to_ids(pipe.trigger_word)
input_ids = pipe.tokenizer.encode(prompt)
......@@ -110,32 +109,40 @@ def generate_image(upload_images, prompt, negative_prompt, aspect_ratio_name, st
).images
return images, gr.update(visible=True)
def swap_to_gallery(images):
return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
def upload_example_to_gallery(images, prompt, style, negative_prompt):
return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
def remove_back_to_files():
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
def remove_tips():
return gr.update(visible=False)
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
if randomize_seed:
seed = random.randint(0, MAX_SEED)
return seed
def apply_style(style_name: str, positive: str, negative: str = ""):
p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
return p.replace("{prompt}", positive), n + ' ' + negative
def get_image_path_list(folder_name):
image_basename_list = os.listdir(folder_name)
image_path_list = sorted([os.path.join(folder_name, basename) for basename in image_basename_list])
return image_path_list
def get_example():
case = [
[
......@@ -153,6 +160,7 @@ def get_example():
]
return case
### Description and style
logo = r"""
<center><img src='https://photo-maker.github.io/assets/logo.png' alt='PhotoMaker logo' style="width:80px; margin-bottom:10px"></center>
......@@ -234,7 +242,8 @@ with gr.Blocks(css=css) as demo:
info="Try something like 'a photo of a man/woman img', 'img' is the trigger word.",
placeholder="A photo of a [man/woman img]...")
style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS, value=DEFAULT_ASPECT_RATIO)
aspect_ratio = gr.Dropdown(label="Output aspect ratio", choices=ASPECT_RATIO_LABELS,
value=DEFAULT_ASPECT_RATIO)
submit = gr.Button("Submit")
with gr.Accordion(open=False, label="Advanced Options"):
......@@ -281,7 +290,7 @@ with gr.Blocks(css=css) as demo:
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Column():
gallery = gr.Gallery(label="Generated Images")
usage_tips = gr.Markdown(label="Usage tips of PhotoMaker", value=tips ,visible=False)
usage_tips = gr.Markdown(label="Usage tips of PhotoMaker", value=tips, visible=False)
files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
......@@ -297,7 +306,8 @@ with gr.Blocks(css=css) as demo:
api_name=False,
).then(
fn=generate_image,
inputs=[files, prompt, negative_prompt, aspect_ratio, style, num_steps, style_strength_ratio, num_outputs, guidance_scale, seed],
inputs=[files, prompt, negative_prompt, aspect_ratio, style, num_steps, style_strength_ratio, num_outputs,
guidance_scale, seed],
outputs=[gallery, usage_tips]
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment