Unverified commit d49a4a70, authored by Muyang Li, committed by GitHub
Browse files

feat: custom model path of PuLID (#465)

* change the model path

* change the model path

* fix the model paths

* rename facexlib_path to facexlib_dirpath

* add start_timestep and end_timestep

* no need to download the files if the folder exists
parent 8fbf418d
......@@ -3,11 +3,13 @@ import logging
import os
import re
from copy import deepcopy
from os import PathLike
from pathlib import Path
from typing import Optional, Tuple, Union
import torch
from ....utils import fetch_or_download
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomCLIP, convert_to_custom_text_state_dict, get_cast_dtype
from .pretrained import download_pretrained, get_pretrained_cfg, list_pretrained_tags_by_model
......@@ -227,6 +229,7 @@ def create_model(
pretrained_text_model: str = None,
cache_dir: Optional[str] = None,
skip_list: list = [],
pretrained_path: str | PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
):
model_name = model_name.replace("/", "-") # for callers using old naming with / in ViT names
if isinstance(device, str):
......@@ -297,12 +300,7 @@ def create_model(
pretrained_cfg = {}
if pretrained:
checkpoint_path = ""
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained):
checkpoint_path = pretrained
checkpoint_path = fetch_or_download(pretrained_path)
if checkpoint_path:
logging.info(f"Loading pretrained {model_name} weights ({pretrained}).")
......@@ -406,6 +404,7 @@ def create_model_and_transforms(
image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None,
skip_list: list = [],
pretrained_path: str | PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
):
model = create_model(
model_name,
......@@ -423,6 +422,7 @@ def create_model_and_transforms(
pretrained_text_model=pretrained_text_model,
cache_dir=cache_dir,
skip_list=skip_list,
pretrained_path=pretrained_path,
)
image_mean = image_mean or getattr(model.visual, "image_mean", None)
......
......@@ -24,6 +24,8 @@ def pulid_forward(
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
start_timestep: float | None = None,
end_timestep: float | None = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
Copied from diffusers.models.flux.transformer_flux.py
......@@ -53,6 +55,16 @@ def pulid_forward(
"""
hidden_states = self.x_embedder(hidden_states)
if timestep.numel() > 1:
timestep_float = timestep.flatten()[0].item()
else:
timestep_float = timestep.item()
if start_timestep is not None and start_timestep > timestep_float:
id_embeddings = None
if end_timestep is not None and end_timestep < timestep_float:
id_embeddings = None
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
......
......@@ -81,7 +81,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
self.id_weight = id_weight
self.pulid_ca_idx = 0
if self.id_embeddings is not None:
self.set_residual_callback()
self.set_pulid_residual_callback()
original_dtype = hidden_states.dtype
original_device = hidden_states.device
......@@ -129,7 +129,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
)
if self.id_embeddings is not None:
self.reset_residual_callback()
self.reset_pulid_residual_callback()
hidden_states = hidden_states.to(original_dtype).to(original_device)
......@@ -194,7 +194,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
return encoder_hidden_states, hidden_states
def set_residual_callback(self):
def set_pulid_residual_callback(self):
id_embeddings = self.id_embeddings
pulid_ca = self.pulid_ca
pulid_ca_idx = [self.pulid_ca_idx]
......@@ -208,7 +208,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
self.callback_holder = callback
self.m.set_residual_callback(callback)
def reset_residual_callback(self):
def reset_pulid_residual_callback(self):
    """Detach the PuLID residual callback previously installed by
    ``set_pulid_residual_callback``.

    Clears both the Python-side reference (``callback_holder``) and the
    callback registered on ``self.m`` by passing ``None``, so subsequent
    forward passes run without PuLID ID-embedding injection.
    """
    # NOTE(review): self.m appears to be the wrapped transformer module whose
    # set_residual_callback(None) de-registers the hook — confirm against the
    # module's binding.
    self.callback_holder = None
    self.m.set_residual_callback(None)
......
# Adapted from https://github.com/ToTheBeginning/PuLID/blob/main/pulid/pipeline.py
import gc
import logging
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
import cv2
......@@ -14,9 +16,9 @@ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.utils import replace_example_docstring
from facexlib.parsing import init_parsing_model
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub import snapshot_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from insightface.app import FaceAnalysis
from safetensors.torch import load_file
from torch import nn
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import normalize, resize
......@@ -25,16 +27,57 @@ from ..models.pulid.encoders_transformer import IDFormer, PerceiverAttentionCA
from ..models.pulid.eva_clip import create_model_and_transforms
from ..models.pulid.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from ..models.pulid.utils import img2tensor, resize_numpy_image_long, tensor2img
from ..models.transformers import NunchakuFluxTransformer2dModel
from ..utils import load_state_dict_in_safetensors, sha256sum
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
# Configure logging
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def check_antelopev2_dir(antelopev2_dirpath: str | os.PathLike[str]) -> bool:
    """Verify that an antelopev2 model directory contains all required ONNX files.

    Each expected file is checked for existence and, unless its entry is the
    sentinel ``"<SKIP_HASH>"``, for an exact SHA-256 match. The check fails
    fast on the first missing or mismatched file so callers can decide to
    re-download the snapshot.

    Args:
        antelopev2_dirpath: Path to the ``antelopev2`` model directory.

    Returns:
        True if the directory exists and every required file is present with
        the expected hash; False otherwise.
    """
    antelopev2_dirpath = Path(antelopev2_dirpath)
    # Expected SHA-256 digests of the antelopev2 ONNX model files.
    required_files = {
        "1k3d68.onnx": "df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc",
        "2d106det.onnx": "f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf",
        "genderage.onnx": "4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb",
        "glintr100.onnx": "4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf",
        "scrfd_10g_bnkps.onnx": "5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91",
    }

    if not antelopev2_dirpath.is_dir():
        logger.debug(f"Directory does not exist: {antelopev2_dirpath}")
        return False

    for filename, expected_hash in required_files.items():
        filepath = antelopev2_dirpath / filename
        if not filepath.exists():
            # Interpolate the actual path so the debug log is actionable
            # (the message previously carried a fixed placeholder).
            logger.debug(f"Missing file: {filepath}")
            return False
        if expected_hash != "<SKIP_HASH>" and sha256sum(filepath) != expected_hash:
            logger.debug(f"Hash mismatch for: {filepath}")
            return False
    return True
class PuLIDPipeline(nn.Module):
def __init__(
self, dit, device, weight_dtype=torch.bfloat16, onnx_provider="gpu", folder_path="models", *args, **kwargs
self,
dit: NunchakuFluxTransformer2dModel,
device: str | torch.device,
weight_dtype: str | torch.dtype = torch.bfloat16,
onnx_provider: str = "gpu",
pulid_path: str | os.PathLike[str] = "guozinan/PuLID/pulid_flux_v0.9.1.safetensors",
eva_clip_path: str | os.PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
insightface_dirpath: str | os.PathLike[str] | None = None,
facexlib_dirpath: str | os.PathLike[str] | None = None,
):
super().__init__()
self.device = device
self.weight_dtype = weight_dtype
self.folder_path = folder_path
double_interval = 2
single_interval = 4
......@@ -54,6 +97,11 @@ class PuLIDPipeline(nn.Module):
# preprocessors
# face align and parsing
if facexlib_dirpath is None:
facexlib_dirpath = Path(HUGGINGFACE_HUB_CACHE) / "facexlib"
facexlib_dirpath = Path(facexlib_dirpath)
self.face_helper = FaceRestoreHelper(
upscale_factor=1,
face_size=512,
......@@ -61,11 +109,17 @@ class PuLIDPipeline(nn.Module):
det_model="retinaface_resnet50",
save_ext="png",
device=self.device,
model_rootpath=str(facexlib_dirpath),
)
self.face_helper.face_parse = None
self.face_helper.face_parse = init_parsing_model(model_name="bisenet", device=self.device)
self.face_helper.face_parse = init_parsing_model(
model_name="bisenet", device=self.device, model_rootpath=str(facexlib_dirpath)
)
# clip-vit backbone
model, _, _ = create_model_and_transforms("EVA02-CLIP-L-14-336", "eva_clip", force_custom_clip=True)
model, _, _ = create_model_and_transforms(
"EVA02-CLIP-L-14-336", "eva_clip", force_custom_clip=True, pretrained_path=eva_clip_path
)
model = model.visual
self.clip_vision_model = model.to(self.device, dtype=self.weight_dtype)
eva_transform_mean = getattr(self.clip_vision_model, "image_mean", OPENAI_DATASET_MEAN)
......@@ -76,46 +130,51 @@ class PuLIDPipeline(nn.Module):
eva_transform_std = (eva_transform_std,) * 3
self.eva_transform_mean = eva_transform_mean
self.eva_transform_std = eva_transform_std
# antelopev2
antelopev2_path = os.path.join(folder_path, "insightface", "models", "antelopev2")
snapshot_download("DIAMONIK7777/antelopev2", local_dir=antelopev2_path)
if insightface_dirpath is None:
insightface_dirpath = Path(HUGGINGFACE_HUB_CACHE) / "insightface"
insightface_dirpath = Path(insightface_dirpath)
if insightface_dirpath is not None:
antelopev2_dirpath = insightface_dirpath / "models" / "antelopev2"
else:
antelopev2_dirpath = None
if antelopev2_dirpath is None or not check_antelopev2_dir(antelopev2_dirpath):
snapshot_download("DIAMONIK7777/antelopev2", local_dir=antelopev2_dirpath)
providers = (
["CPUExecutionProvider"] if onnx_provider == "cpu" else ["CUDAExecutionProvider", "CPUExecutionProvider"]
)
self.app = FaceAnalysis(name="antelopev2", root=os.path.join(folder_path, "insightface"), providers=providers)
self.app = FaceAnalysis(name="antelopev2", root=insightface_dirpath, providers=providers)
self.app.prepare(ctx_id=0, det_size=(640, 640))
self.handler_ante = insightface.model_zoo.get_model(
os.path.join(antelopev2_path, "glintr100.onnx"), providers=providers
str(antelopev2_dirpath / "glintr100.onnx"), providers=providers
)
self.handler_ante.prepare(ctx_id=0)
gc.collect()
torch.cuda.empty_cache()
# other configs
self.debug_img_list = []
# pulid model
state_dict = load_state_dict_in_safetensors(pulid_path)
module_state_dict = {}
def load_pretrain(self, pretrain_path=None, version="v0.9.1"):
hf_hub_download(
"guozinan/PuLID", f"pulid_flux_{version}.safetensors", local_dir=os.path.join(self.folder_path, "pulid")
)
ckpt_path = os.path.join(self.folder_path, "pulid", f"pulid_flux_{version}.safetensors")
if pretrain_path is not None:
ckpt_path = pretrain_path
state_dict = load_file(ckpt_path)
state_dict_dict = {}
for k, v in state_dict.items():
module = k.split(".")[0]
state_dict_dict.setdefault(module, {})
module_state_dict.setdefault(module, {})
new_k = k[len(module) + 1 :]
state_dict_dict[module][new_k] = v
module_state_dict[module][new_k] = v
for module in state_dict_dict:
print(f"loading from {module}")
getattr(self, module).load_state_dict(state_dict_dict[module], strict=True)
for module in module_state_dict:
logging.debug(f"loading from {module}")
getattr(self, module).load_state_dict(module_state_dict[module], strict=True)
del state_dict
del state_dict_dict
del module_state_dict
gc.collect()
torch.cuda.empty_cache()
# other configs
self.debug_img_list = []
def to_gray(self, img):
x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
......@@ -215,8 +274,6 @@ class PuLIDFluxPipeline(FluxPipeline):
pulid_device="cuda",
weight_dtype=torch.bfloat16,
onnx_provider="gpu",
pretrained_model=None,
folder_path="models",
):
super().__init__(
scheduler=scheduler,
......@@ -241,9 +298,7 @@ class PuLIDFluxPipeline(FluxPipeline):
device=self.pulid_device,
weight_dtype=self.weight_dtype,
onnx_provider=self.onnx_provider,
folder_path=folder_path,
)
self.pulid_model.load_pretrain(pretrained_model)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
......
import hashlib
import os
import warnings
from pathlib import Path
......@@ -7,6 +8,14 @@ import torch
from huggingface_hub import hf_hub_download
def sha256sum(filepath: str | os.PathLike[str]) -> str:
sha256 = hashlib.sha256()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
sha256.update(chunk)
return sha256.hexdigest()
def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
path = Path(path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment