"vscode:/vscode.git/clone" did not exist on "cafa6a9e29f3e99c67a1028f8ca779d439bc0689"
Unverified Commit 6ec44fc1 authored by Yorick's avatar Yorick Committed by GitHub
Browse files

Update replicate code and readme (#73)



* Add working cog demo

* Add safety checker

* Print seed when doing a prediction

* Update README.md

* Update for upstream changes, update torch==2.1.1, pget model weights

* Add up to four inputs and outputs in cog demo

And clean up predict.py based on review feedback

* Add some debug output for users

* Improve cog parameters based on feedback

* Strip trigger_word from prompt using tokenizer

* Fix guidance scale bug, allow no classifier free guidance in pipeline.

* Update README.md

* Disallow trigger word in negative prompt

---------
Co-authored-by: Judith van Stegeren <judith@replicate.com>
Co-authored-by: Judith van Stegeren <690008+jd7h@users.noreply.github.com>
parent 5c4839a7
# The .dockerignore file excludes files from the container build process.
#
# https://docs.docker.com/engine/reference/builder/#dockerignore-file
# Exclude Git files
.git
.github
.gitignore
# Exclude Python cache files
__pycache__
.mypy_cache
.pytest_cache
.ruff_cache
# Exclude Python virtual environment
/venv
# Exclude output files
/outputs
output*.png
# Exclude models cache
/models
......@@ -9,7 +9,9 @@
## PhotoMaker: Customizing Realistic Human Photos via Stacked ID Embedding [![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md-dark.svg)](https://huggingface.co/papers/2312.04461)
[[Paper](https://huggingface.co/papers/2312.04461)] &emsp; [[Project Page](https://photo-maker.github.io)] &emsp; [[Model Card](https://huggingface.co/TencentARC/PhotoMaker)] <br>
[[🤗 Demo (Realistic)](https://huggingface.co/spaces/TencentARC/PhotoMaker)] &emsp; [[🤗 Demo (Stylization)](https://huggingface.co/spaces/TencentARC/PhotoMaker-Style)] <be>
[[🤗 Demo (Realistic)](https://huggingface.co/spaces/TencentARC/PhotoMaker)] &emsp; [[🤗 Demo (Stylization)](https://huggingface.co/spaces/TencentARC/PhotoMaker-Style)] <br>
[[Replicate Demo (Realistic)](https://replicate.com/jd7h/photomaker)] &emsp; [[Replicate Demo (Stylization)](https://replicate.com/yorickvp/photomaker-style)] <br>
If the ID fidelity is not enough for you, please try our [stylization application](https://huggingface.co/spaces/TencentARC/PhotoMaker-Style), you may be pleasantly surprised.
</div>
......@@ -199,7 +201,9 @@ If you want to run it on MAC, you should follow [this Instruction](MacGPUEnv.md)
# Related Resources
### Replicate demo of PhotoMaker:
[Demo link](https://replicate.com/jd7h/photomaker) by [@yorickvP](https://github.com/yorickvP), transfer PhotoMaker to replicate.
1. [Demo link](https://replicate.com/jd7h/photomaker), run PhotoMaker on replicate.
2. [Demo link (style version)](https://replicate.com/yorickvp/photomaker-style).
### Windows version of PhotoMaker:
1. [bmaltais/PhotoMaker](https://github.com/bmaltais/PhotoMaker/tree/v1.0.1) by [@bmaltais](https://github.com/bmaltais), easy to deploy PhotoMaker on Windows. The description can be found in [this link](https://github.com/TencentARC/PhotoMaker/discussions/36#discussioncomment-8156199).
2. [sdbds/PhotoMaker-for-windows](https://github.com/sdbds/PhotoMaker-for-windows/tree/windows) by [@sdbds](https://github.com/sdbds).
......
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
build:
# set to true if your model requires a GPU
gpu: true
cuda: "11.7"
python_version: "3.8"
# python version in the form '3.11' or '3.11.4'
python_version: "3.11"
python_packages:
- "torch==2.0.1"
- "torchvision==0.15.2"
- "accelerate==0.26.1"
- "diffusers==0.25.0"
- "transformers==4.36.2"
- "huggingface-hub==0.20.2"
- "numpy"
- "accelerate"
- "safetensors"
- "omegaconf"
- "peft"
- "numpy==1.24.4"
- "omegaconf==2.3.0"
- "peft==0.7.1"
- "safetensors==0.4.1"
- "torch==2.1.1"
- "torchvision==0.16.1"
- "transformers==4.36.2"
run:
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.5.6/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
image: r8.im/tencentarc/photomaker
{
"crop_size": 224,
"do_center_crop": true,
"do_convert_rgb": true,
"do_normalize": true,
"do_resize": true,
"feature_extractor_type": "CLIPFeatureExtractor",
"image_mean": [
0.48145466,
0.4578275,
0.40821073
],
"image_std": [
0.26862954,
0.26130258,
0.27577711
],
"resample": 3,
"size": 224
}
......@@ -322,7 +322,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
do_classifier_free_guidance = guidance_scale >= 1.0
assert do_classifier_free_guidance
......
......@@ -2,109 +2,238 @@
# https://github.com/replicate/cog/blob/main/docs/python.md
from cog import BasePredictor, Input, Path
import torch
import numpy as np
import random
import os
from PIL import Image
import logging
import time
from typing import List
import shutil
import subprocess
import time
os.environ["HF_HUB_CACHE"] = "models"
os.environ["HF_HUB_CACHE_OFFLINE"] = "true"
from diffusers.utils import load_image
from diffusers import EulerDiscreteScheduler
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from huggingface_hub import hf_hub_download
from transformers import CLIPImageProcessor
from photomaker import PhotoMakerStableDiffusionXLPipeline
from gradio_demo.style_template import styles
MAX_SEED = np.iinfo(np.int32).max
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Photographic (Default)"
FEATURE_EXTRACTOR = "./feature-extractor"
SAFETY_CACHE = "./models/safety-cache"
SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar"
from photomaker.pipeline import PhotoMakerStableDiffusionXLPipeline
BASE_MODEL_URL = "https://weights.replicate.delivery/default/SG161222--RealVisXL_V3.0-11ee564ebf4bd96d90ed5d473cb8e7f2e6450bcf.tar"
BASE_MODEL_PATH = "models/SG161222/RealVisXL_V3.0"
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
logger = logging.getLogger(__name__)
PHOTOMAKER_URL = "https://weights.replicate.delivery/default/TencentARC--PhotoMaker/photomaker-v1.bin"
PHOTOMAKER_PATH = "models/photomaker-v1.bin"
base_model_path = 'SG161222/RealVisXL_V3.0'
photomaker_path = 'release_model/photomaker-v1.bin'
device = "cuda"
def download_weights(url, dest, extract=True):
    """Download model weights from `url` to `dest` using the `pget` binary.

    When `extract` is true (the default) the download is treated as a tar
    archive and unpacked into `dest` via pget's `-x` flag; otherwise the
    file is saved as-is. Raises CalledProcessError if pget fails.
    """
    began = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    # pget is installed into /usr/local/bin by the cog.yaml `run` step
    command = ["pget", "-x"] if extract else ["pget"]
    subprocess.check_call([*command, url, dest], close_fds=False)
    print("downloading took: ", time.time() - began)
# utility function for style templates
def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
    """Expand a named style template around the user's prompts.

    Looks up `style_name` in the `styles` table (falling back to the
    default template), substitutes the positive prompt into its
    "{prompt}" placeholder, and appends the user's negative prompt to
    the template's negative prompt.
    """
    template_positive, template_negative = styles.get(
        style_name, styles[DEFAULT_STYLE_NAME]
    )
    return (
        template_positive.replace("{prompt}", positive),
        template_negative + " " + negative,
    )
class Predictor(BasePredictor):
def setup(self) -> None:
"""Load the model into memory to make running multiple predictions efficient"""
start = time.time()
logger.info("Loading model...")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# download PhotoMaker checkpoint to cache
# if we already have the model, this doesn't do anything
if not os.path.exists(PHOTOMAKER_PATH):
download_weights(PHOTOMAKER_URL, PHOTOMAKER_PATH, extract=False)
if not os.path.exists(BASE_MODEL_PATH):
download_weights(BASE_MODEL_URL, BASE_MODEL_PATH)
print("Loading safety checker...")
if not os.path.exists(SAFETY_CACHE):
download_weights(SAFETY_URL, SAFETY_CACHE)
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
SAFETY_CACHE, torch_dtype=torch.float16
).to("cuda")
self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)
self.pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
base_model_path,
torch_dtype=torch.bfloat16,
use_safetensors=True,
variant="fp16"
).to(device)
BASE_MODEL_PATH,
torch_dtype=torch.bfloat16,
use_safetensors=True,
variant="fp16",
).to(self.device)
self.pipe.load_photomaker_adapter(
os.path.dirname(photomaker_path),
subfolder="",
weight_name=os.path.basename(photomaker_path),
trigger_word="img"
)
self.pipe.scheduler = EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)
os.path.dirname(PHOTOMAKER_PATH),
subfolder="",
weight_name=os.path.basename(PHOTOMAKER_PATH),
trigger_word="img",
)
self.pipe.id_encoder.to(self.device)
self.pipe.scheduler = EulerDiscreteScheduler.from_config(
self.pipe.scheduler.config
)
self.pipe.fuse_lora()
logger.info(f"Loaded model in {time.time() - start:.06}s")
def _load_image(self, path):
shutil.copyfile(path, "/tmp/image.png")
return load_image("/tmp/image.png").convert("RGB")
@torch.inference_mode()
def predict(
self,
input_image: Path = Input(
description="The input image, for example a photo of your face."
),
input_image2: Path = Input(
description="Additional input image (optional)",
default=None
),
input_image3: Path = Input(
description="Additional input image (optional)",
default=None
),
input_image4: Path = Input(
description="Additional input image (optional)",
default=None
),
prompt: str = Input(
description="Input prompt",
default="sci-fi, closeup portrait photo of a man img wearing the sunglasses in Iron man suit, face, slim body, high quality, film grain"
description="Prompt. Example: 'a photo of a man/woman img'. The phrase 'img' is the trigger word.",
default="A photo of a person img",
),
style_name: str = Input(
description="Style template. The style template will add a style-specific prompt and negative prompt to the user's prompt.",
choices=STYLE_NAMES,
default=DEFAULT_STYLE_NAME,
),
negative_prompt: str = Input(
description="Negative Input prompt",
default="(asymmetry, worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), open mouth"
description="Negative Prompt. The negative prompt should NOT contain the trigger word.",
default="nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
),
image: Path = Input(
description="Input image for img2img or inpaint mode",
default=None,
num_steps: int = Input(
description="Number of sample steps", default=20, ge=1, le=100
),
seed: int = Input(
description="Random seed. Leave blank to randomize the seed", default=None
style_strength_ratio: float = Input(
description="Style strength (%)", default=20, ge=15, le=50
),
num_outputs: int = Input(
description="Number of images to output.",
ge=1,
le=4,
default=1,
description="Number of output images", default=1, ge=1, le=4
),
guidance_scale: float = Input(
description="Guidance scale. A guidance scale of 1 corresponds to doing no classifier free guidance.", default=5, ge=1, le=10.0
),
num_inference_steps: int = Input(
description="Number of denoising steps", ge=1, le=500, default=40
seed: int = Input(description="Seed. Leave blank to use a random number", default=None, ge=0, le=MAX_SEED),
disable_safety_checker: bool = Input(
description="Disable safety checker for generated images.",
default=False
)
) -> List[Path]:
) -> list[Path]:
"""Run a single prediction on the model"""
# remove old outputs
output_folder = Path('outputs')
if output_folder.exists():
shutil.rmtree(output_folder)
os.makedirs(str(output_folder), exist_ok=False)
# randomize seed if necessary
if seed is None:
seed = int.from_bytes(os.urandom(4), "big")
logger.info(f"Using seed: {seed}")
generator = torch.Generator("cuda").manual_seed(seed)
seed = random.randint(0, MAX_SEED)
print(f"Using seed {seed}...")
# check the prompt for the trigger word
image_token_id = self.pipe.tokenizer.convert_tokens_to_ids(self.pipe.trigger_word)
input_ids = self.pipe.tokenizer.encode(prompt)
if image_token_id not in input_ids:
raise ValueError(
f"Cannot find the trigger word '{self.pipe.trigger_word}' in text prompt!")
style_strength_ratio = 20
start_merge_step = int(float(style_strength_ratio) / 100 * num_inference_steps)
if input_ids.count(image_token_id) > 1:
raise ValueError(
f"Cannot use multiple trigger words '{self.pipe.trigger_word}' in text prompt!"
)
# check the negative prompt for the trigger word
if negative_prompt:
negative_prompt_ids = self.pipe.tokenizer.encode(negative_prompt)
if image_token_id in negative_prompt_ids:
raise ValueError(
f"Cannot use trigger word '{self.pipe.trigger_word}' in negative prompt!"
)
# apply the style template
prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
# load the input images
input_id_images = []
for maybe_image in [input_image, input_image2, input_image3, input_image4]:
if maybe_image:
print(f"Loading image {maybe_image}...")
input_id_images.append(load_image(str(maybe_image)))
print(f"Setting seed...")
generator = torch.Generator(device=self.device).manual_seed(seed)
print("Start inference...")
print(f"[Debug] Prompt: {prompt}")
print(f"[Debug] Neg Prompt: {negative_prompt}")
start_merge_step = int(float(style_strength_ratio) / 100 * num_steps)
if start_merge_step > 30:
start_merge_step = 30
print(f"Start merge step: {start_merge_step}")
images = self.pipe(
prompt=prompt,
input_id_images=[self._load_image(image)],
input_id_images=input_id_images,
negative_prompt=negative_prompt,
num_images_per_prompt=num_outputs,
num_inference_steps=num_inference_steps,
num_images_per_prompt=num_outputs,
num_inference_steps=num_steps,
start_merge_step=start_merge_step,
generator=generator,
guidance_scale=guidance_scale,
).images
if not disable_safety_checker:
print(f"Running safety checker...")
_, has_nsfw_content = self.run_safety_checker(images)
# save results to file
print(f"Saving images to file...")
output_paths = []
for i, image in enumerate(images):
output_path = f"/tmp/out-{i}.png"
if not disable_safety_checker:
if has_nsfw_content[i]:
print(f"NSFW content detected in image {i}")
continue
output_path = output_folder / f"image_{i}.png"
image.save(output_path)
output_paths.append(Path(output_path))
output_paths.append(output_path)
return [Path(p) for p in output_paths]
return output_paths
\ No newline at end of file
def run_safety_checker(self, image):
safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
"cuda"
)
np_image = [np.array(val) for val in image]
image, has_nsfw_concept = self.safety_checker(
images=np_image,
clip_input=safety_checker_input.pixel_values.to(torch.float16),
)
return image, has_nsfw_concept
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment