Unverified Commit 6ec44fc1 authored by Yorick, committed by GitHub

Update replicate code and readme (#73)



* Add working cog demo

* Add safety checker

* Print seed when doing a prediction

* Update README.md

* Update for upstream changes, update torch==2.1.1, pget model weights

* Add up to four inputs and outputs in cog demo

And clean up predict.py based on review feedback

* Add some debug output for users

* Improve cog parameters based on feedback

* Strip trigger_word from prompt using tokenizer

* Fix guidance scale bug, allow no classifier free guidance in pipeline.

* Update README.md

* Disallow trigger word in negative prompt

---------
Co-authored-by: Judith van Stegeren <judith@replicate.com>
Co-authored-by: Judith van Stegeren <690008+jd7h@users.noreply.github.com>
parent 5c4839a7
.dockerignore
# The .dockerignore file excludes files from the container build process.
#
# https://docs.docker.com/engine/reference/builder/#dockerignore-file
# Exclude Git files
.git
.github
.gitignore
# Exclude Python cache files
__pycache__
.mypy_cache
.pytest_cache
.ruff_cache
# Exclude Python virtual environment
/venv
# Exclude output files
/outputs
output*.png
# Exclude models cache
/models
README.md
...
@@ -9,7 +9,9 @@
 ## PhotoMaker: Customizing Realistic Human Photos via Stacked ID Embedding [![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md-dark.svg)](https://huggingface.co/papers/2312.04461)
 [[Paper](https://huggingface.co/papers/2312.04461)] &emsp; [[Project Page](https://photo-maker.github.io)] &emsp; [[Model Card](https://huggingface.co/TencentARC/PhotoMaker)] <br>
-[[🤗 Demo (Realistic)](https://huggingface.co/spaces/TencentARC/PhotoMaker)] &emsp; [[🤗 Demo (Stylization)](https://huggingface.co/spaces/TencentARC/PhotoMaker-Style)] <be>
+[[🤗 Demo (Realistic)](https://huggingface.co/spaces/TencentARC/PhotoMaker)] &emsp; [[🤗 Demo (Stylization)](https://huggingface.co/spaces/TencentARC/PhotoMaker-Style)] <br>
+[[Replicate Demo (Realistic)](https://replicate.com/jd7h/photomaker)] &emsp; [[Replicate Demo (Stylization)](https://replicate.com/yorickvp/photomaker-style)] <br>
 If the ID fidelity is not enough for you, please try our [stylization application](https://huggingface.co/spaces/TencentARC/PhotoMaker-Style), you may be pleasantly surprised.
 </div>
...
@@ -199,7 +201,9 @@ If you want to run it on MAC, you should follow [this Instruction](MacGPUEnv.md)
 # Related Resources
 ### Replicate demo of PhotoMaker:
-[Demo link](https://replicate.com/jd7h/photomaker) by [@yorickvP](https://github.com/yorickvP), transfer PhotoMaker to replicate.
+1. [Demo link](https://replicate.com/jd7h/photomaker), run PhotoMaker on Replicate.
+2. [Demo link (style version)](https://replicate.com/yorickvp/photomaker-style).
 ### Windows version of PhotoMaker:
 1. [bmaltais/PhotoMaker](https://github.com/bmaltais/PhotoMaker/tree/v1.0.1) by [@bmaltais](https://github.com/bmaltais), easy to deploy PhotoMaker on Windows. The description can be found in [this link](https://github.com/TencentARC/PhotoMaker/discussions/36#discussioncomment-8156199).
 2. [sdbds/PhotoMaker-for-windows](https://github.com/sdbds/PhotoMaker-for-windows/tree/windows) by [@sdbds](https://github.com/sdbds).
...
cog.yaml
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
 build:
+  # set to true if your model requires a GPU
   gpu: true
-  cuda: "11.7"
-  python_version: "3.8"
+  # python version in the form '3.11' or '3.11.4'
+  python_version: "3.11"
   python_packages:
-    - "torch==2.0.1"
-    - "torchvision==0.15.2"
+    - "accelerate==0.26.1"
     - "diffusers==0.25.0"
-    - "transformers==4.36.2"
     - "huggingface-hub==0.20.2"
-    - "numpy"
-    - "accelerate"
-    - "safetensors"
-    - "omegaconf"
-    - "peft"
+    - "numpy==1.24.4"
+    - "omegaconf==2.3.0"
+    - "peft==0.7.1"
+    - "safetensors==0.4.1"
+    - "torch==2.1.1"
+    - "torchvision==0.16.1"
+    - "transformers==4.36.2"
+  run:
+    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.5.6/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
 # predict.py defines how predictions are run on your model
 predict: "predict.py:Predictor"
+image: r8.im/tencentarc/photomaker
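The `image:` line names the published container (`r8.im/tencentarc/photomaker`). As a minimal sketch, the deployment can be called from Python with the official `replicate` client; this example is illustrative rather than part of the commit, and assumes `pip install replicate`, a `REPLICATE_API_TOKEN` in the environment, and that in real use you would pin an exact model version. The input names match `predict.py` below.

import replicate

# Hedged sketch: run the deployed PhotoMaker model.
# Pin a version ("tencentarc/photomaker:<version-hash>") in real use.
output = replicate.run(
    "tencentarc/photomaker",
    input={
        "input_image": open("face.jpg", "rb"),  # a clear photo of a face
        "prompt": "A photo of a person img",    # must contain the trigger word "img"
        "num_steps": 20,
        "num_outputs": 1,
    },
)
print(output)  # list of URLs for the generated images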
feature-extractor/preprocessor_config.json
{
"crop_size": 224,
"do_center_crop": true,
"do_convert_rgb": true,
"do_normalize": true,
"do_resize": true,
"feature_extractor_type": "CLIPFeatureExtractor",
"image_mean": [
0.48145466,
0.4578275,
0.40821073
],
"image_std": [
0.26862954,
0.26130258,
0.27577711
],
"resample": 3,
"size": 224
}
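This is the preprocessor config that `predict.py` loads with `CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)` to prepare images for the safety checker: resize the short side to 224, center-crop to 224x224, and normalize with the CLIP mean/std listed above. A small sketch of the resulting tensor, assuming `transformers`, `torch`, and `Pillow` are installed:

from PIL import Image
from transformers import CLIPImageProcessor

# Load the bundled config from the repo's feature-extractor directory.
processor = CLIPImageProcessor.from_pretrained("./feature-extractor")
dummy = Image.new("RGB", (1024, 768))
pixel_values = processor(dummy, return_tensors="pt").pixel_values
print(pixel_values.shape)  # torch.Size([1, 3, 224, 224])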
photomaker/pipeline.py
...
@@ -322,7 +322,7 @@ class PhotoMakerStableDiffusionXLPipeline(StableDiffusionXLPipeline):
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
+        do_classifier_free_guidance = guidance_scale >= 1.0
         assert do_classifier_free_guidance
...
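For context, classifier-free guidance forms `uncond + w * (text - uncond)` (equation 2 of the Imagen paper), so `w = 1` reduces to the plain text-conditioned prediction. Relaxing the check from `> 1.0` to `>= 1.0` therefore lets callers pass `guidance_scale=1` to get effectively no guidance, while the pipeline keeps its batched CFG code path (the retained assert shows it still requires that path). A standalone sketch of the arithmetic, with illustrative tensor names:

import torch

def cfg_combine(noise_uncond, noise_text, w):
    # Imagen eq. 2: with w == 1 this equals the conditional prediction,
    # i.e. no classifier-free guidance.
    return noise_uncond + w * (noise_text - noise_uncond)

uncond = torch.randn(1, 4, 64, 64)
text = torch.randn(1, 4, 64, 64)
assert torch.allclose(cfg_combine(uncond, text, 1.0), text)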
predict.py
...
@@ -2,109 +2,238 @@
 # https://github.com/replicate/cog/blob/main/docs/python.md
 from cog import BasePredictor, Input, Path
 import torch
 import numpy as np
 import random
 import os
-from PIL import Image
-import logging
-import time
-from typing import List
 import shutil
+import subprocess
+import time
+
+os.environ["HF_HUB_CACHE"] = "models"
+os.environ["HF_HUB_CACHE_OFFLINE"] = "true"

 from diffusers.utils import load_image
 from diffusers import EulerDiscreteScheduler
-from photomaker.pipeline import PhotoMakerStableDiffusionXLPipeline
-
-logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
-logger = logging.getLogger(__name__)
-
-base_model_path = 'SG161222/RealVisXL_V3.0'
-photomaker_path = 'release_model/photomaker-v1.bin'
-device = "cuda"
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+    StableDiffusionSafetyChecker,
+)
+from huggingface_hub import hf_hub_download
+from transformers import CLIPImageProcessor
+
+from photomaker import PhotoMakerStableDiffusionXLPipeline
+from gradio_demo.style_template import styles
+
+MAX_SEED = np.iinfo(np.int32).max
+STYLE_NAMES = list(styles.keys())
+DEFAULT_STYLE_NAME = "Photographic (Default)"
+
+FEATURE_EXTRACTOR = "./feature-extractor"
+SAFETY_CACHE = "./models/safety-cache"
+SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar"
+BASE_MODEL_URL = "https://weights.replicate.delivery/default/SG161222--RealVisXL_V3.0-11ee564ebf4bd96d90ed5d473cb8e7f2e6450bcf.tar"
+BASE_MODEL_PATH = "models/SG161222/RealVisXL_V3.0"
+PHOTOMAKER_URL = "https://weights.replicate.delivery/default/TencentARC--PhotoMaker/photomaker-v1.bin"
+PHOTOMAKER_PATH = "models/photomaker-v1.bin"
+
+
+def download_weights(url, dest, extract=True):
+    start = time.time()
+    print("downloading url: ", url)
+    print("downloading to: ", dest)
+    args = ["pget"]
+    if extract:
+        args.append("-x")
+    subprocess.check_call(args + [url, dest], close_fds=False)
+    print("downloading took: ", time.time() - start)
+
+
+# utility function for style templates
+def apply_style(style_name: str, positive: str, negative: str = "") -> tuple[str, str]:
+    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
+    return p.replace("{prompt}", positive), n + " " + negative


 class Predictor(BasePredictor):
     def setup(self) -> None:
         """Load the model into memory to make running multiple predictions efficient"""
-        start = time.time()
-        logger.info("Loading model...")
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        # download the PhotoMaker checkpoint to cache;
+        # if we already have the model, this doesn't do anything
+        if not os.path.exists(PHOTOMAKER_PATH):
+            download_weights(PHOTOMAKER_URL, PHOTOMAKER_PATH, extract=False)
+        if not os.path.exists(BASE_MODEL_PATH):
+            download_weights(BASE_MODEL_URL, BASE_MODEL_PATH)
+        print("Loading safety checker...")
+        if not os.path.exists(SAFETY_CACHE):
+            download_weights(SAFETY_URL, SAFETY_CACHE)
+        self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+            SAFETY_CACHE, torch_dtype=torch.float16
+        ).to("cuda")
+        self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)
         self.pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
-            base_model_path,
+            BASE_MODEL_PATH,
             torch_dtype=torch.bfloat16,
             use_safetensors=True,
-            variant="fp16"
-        ).to(device)
+            variant="fp16",
+        ).to(self.device)
         self.pipe.load_photomaker_adapter(
-            os.path.dirname(photomaker_path),
+            os.path.dirname(PHOTOMAKER_PATH),
             subfolder="",
-            weight_name=os.path.basename(photomaker_path),
-            trigger_word="img"
+            weight_name=os.path.basename(PHOTOMAKER_PATH),
+            trigger_word="img",
         )
-        self.pipe.scheduler = EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.id_encoder.to(self.device)
+        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
+            self.pipe.scheduler.config
+        )
         self.pipe.fuse_lora()
-        logger.info(f"Loaded model in {time.time() - start:.06}s")
-
-    def _load_image(self, path):
-        shutil.copyfile(path, "/tmp/image.png")
-        return load_image("/tmp/image.png").convert("RGB")

     @torch.inference_mode()
     def predict(
         self,
+        input_image: Path = Input(
+            description="The input image, for example a photo of your face."
+        ),
+        input_image2: Path = Input(
+            description="Additional input image (optional)",
+            default=None,
+        ),
+        input_image3: Path = Input(
+            description="Additional input image (optional)",
+            default=None,
+        ),
+        input_image4: Path = Input(
+            description="Additional input image (optional)",
+            default=None,
+        ),
         prompt: str = Input(
-            description="Input prompt",
-            default="sci-fi, closeup portrait photo of a man img wearing the sunglasses in Iron man suit, face, slim body, high quality, film grain"
+            description="Prompt. Example: 'a photo of a man/woman img'. The word 'img' is the trigger word.",
+            default="A photo of a person img",
+        ),
+        style_name: str = Input(
+            description="Style template. The style template will add a style-specific prompt and negative prompt to the user's prompt.",
+            choices=STYLE_NAMES,
+            default=DEFAULT_STYLE_NAME,
         ),
         negative_prompt: str = Input(
-            description="Negative Input prompt",
-            default="(asymmetry, worst quality, low quality, illustration, 3d, 2d, painting, cartoons, sketch), open mouth"
+            description="Negative prompt. The negative prompt should NOT contain the trigger word.",
+            default="nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
         ),
-        image: Path = Input(
-            description="Input image for img2img or inpaint mode",
-            default=None,
+        num_steps: int = Input(
+            description="Number of sample steps", default=20, ge=1, le=100
         ),
-        seed: int = Input(
-            description="Random seed. Leave blank to randomize the seed", default=None
+        style_strength_ratio: float = Input(
+            description="Style strength (%)", default=20, ge=15, le=50
         ),
         num_outputs: int = Input(
-            description="Number of images to output.",
-            ge=1,
-            le=4,
-            default=1,
+            description="Number of output images", default=1, ge=1, le=4
+        ),
+        guidance_scale: float = Input(
+            description="Guidance scale. A guidance scale of 1 corresponds to doing no classifier-free guidance.",
+            default=5, ge=1, le=10.0,
+        ),
+        seed: int = Input(
+            description="Seed. Leave blank to use a random number",
+            default=None, ge=0, le=MAX_SEED,
+        ),
+        disable_safety_checker: bool = Input(
+            description="Disable safety checker for generated images.",
+            default=False,
         ),
-        num_inference_steps: int = Input(
-            description="Number of denoising steps", ge=1, le=500, default=40
-        )
-    ) -> List[Path]:
+    ) -> list[Path]:
         """Run a single prediction on the model"""
+        # remove old outputs
+        output_folder = Path("outputs")
+        if output_folder.exists():
+            shutil.rmtree(output_folder)
+        os.makedirs(str(output_folder), exist_ok=False)
+
+        # randomize seed if necessary
         if seed is None:
-            seed = int.from_bytes(os.urandom(4), "big")
-        logger.info(f"Using seed: {seed}")
-        generator = torch.Generator("cuda").manual_seed(seed)
+            seed = random.randint(0, MAX_SEED)
+        print(f"Using seed {seed}...")
+
+        # check the prompt for the trigger word
+        image_token_id = self.pipe.tokenizer.convert_tokens_to_ids(self.pipe.trigger_word)
+        input_ids = self.pipe.tokenizer.encode(prompt)
+        if image_token_id not in input_ids:
+            raise ValueError(
+                f"Cannot find the trigger word '{self.pipe.trigger_word}' in text prompt!"
+            )
+        if input_ids.count(image_token_id) > 1:
+            raise ValueError(
+                f"Cannot use multiple trigger words '{self.pipe.trigger_word}' in text prompt!"
+            )
+
+        # check the negative prompt for the trigger word
+        if negative_prompt:
+            negative_prompt_ids = self.pipe.tokenizer.encode(negative_prompt)
+            if image_token_id in negative_prompt_ids:
+                raise ValueError(
+                    f"Cannot use trigger word '{self.pipe.trigger_word}' in negative prompt!"
+                )
+
+        # apply the style template
+        prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
+
+        # load the input images
+        input_id_images = []
+        for maybe_image in [input_image, input_image2, input_image3, input_image4]:
+            if maybe_image:
+                print(f"Loading image {maybe_image}...")
+                input_id_images.append(load_image(str(maybe_image)))

-        style_strength_ratio = 20
-        start_merge_step = int(float(style_strength_ratio) / 100 * num_inference_steps)
+        print("Setting seed...")
+        generator = torch.Generator(device=self.device).manual_seed(seed)
+
+        print("Start inference...")
+        print(f"[Debug] Prompt: {prompt}")
+        print(f"[Debug] Neg Prompt: {negative_prompt}")
+
+        start_merge_step = int(float(style_strength_ratio) / 100 * num_steps)
         if start_merge_step > 30:
             start_merge_step = 30
+        print(f"Start merge step: {start_merge_step}")

         images = self.pipe(
             prompt=prompt,
-            input_id_images=[self._load_image(image)],
+            input_id_images=input_id_images,
             negative_prompt=negative_prompt,
             num_images_per_prompt=num_outputs,
-            num_inference_steps=num_inference_steps,
+            num_inference_steps=num_steps,
             start_merge_step=start_merge_step,
             generator=generator,
+            guidance_scale=guidance_scale,
         ).images

+        if not disable_safety_checker:
+            print("Running safety checker...")
+            _, has_nsfw_content = self.run_safety_checker(images)
+
+        # save results to file
+        print("Saving images to file...")
         output_paths = []
         for i, image in enumerate(images):
-            output_path = f"/tmp/out-{i}.png"
+            if not disable_safety_checker:
+                if has_nsfw_content[i]:
+                    print(f"NSFW content detected in image {i}")
+                    continue
+            output_path = output_folder / f"image_{i}.png"
             image.save(output_path)
-            output_paths.append(Path(output_path))
-        return output_paths
\ No newline at end of file
+            output_paths.append(output_path)
+        return [Path(p) for p in output_paths]
+
+    def run_safety_checker(self, image):
+        safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
+            "cuda"
+        )
+        np_image = [np.array(val) for val in image]
+        image, has_nsfw_concept = self.safety_checker(
+            images=np_image,
+            clip_input=safety_checker_input.pixel_values.to(torch.float16),
+        )
+        return image, has_nsfw_concept
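Outside of Cog's HTTP server, the predictor can also be exercised directly for a quick local test. A hypothetical sketch, assuming a CUDA machine and that setup() can either find the weights under ./models or fetch them with pget; every argument is passed explicitly because the Input(...) defaults are only applied by Cog itself, and "face.jpg" is a placeholder input photo:

from cog import Path
from predict import Predictor

# Hypothetical local smoke test for predict.py.
predictor = Predictor()
predictor.setup()  # downloads weights on first run
outputs = predictor.predict(
    input_image=Path("face.jpg"),
    input_image2=None,
    input_image3=None,
    input_image4=None,
    prompt="A photo of a person img",
    style_name="Photographic (Default)",
    negative_prompt="",
    num_steps=20,
    style_strength_ratio=20,
    num_outputs=1,
    guidance_scale=5,
    seed=42,
    disable_safety_checker=False,
)
print(outputs)  # e.g. [Path('outputs/image_0.png')]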