Unverified Commit d49a4a70 authored by Muyang Li's avatar Muyang Li Committed by GitHub
Browse files

feat: custom model path of PuLID (#465)

* change the model path

* change the model path

* fix the model paths

* rename facexlib_path to facexlib_dirpath

* add start_timestep and end_timestep

* no need to download the files if the folder exists
parent 8fbf418d
...@@ -3,11 +3,13 @@ import logging ...@@ -3,11 +3,13 @@ import logging
import os import os
import re import re
from copy import deepcopy from copy import deepcopy
from os import PathLike
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
from ....utils import fetch_or_download
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomCLIP, convert_to_custom_text_state_dict, get_cast_dtype from .model import CLIP, CustomCLIP, convert_to_custom_text_state_dict, get_cast_dtype
from .pretrained import download_pretrained, get_pretrained_cfg, list_pretrained_tags_by_model from .pretrained import download_pretrained, get_pretrained_cfg, list_pretrained_tags_by_model
...@@ -227,6 +229,7 @@ def create_model( ...@@ -227,6 +229,7 @@ def create_model(
pretrained_text_model: str = None, pretrained_text_model: str = None,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
skip_list: list = [], skip_list: list = [],
pretrained_path: str | PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
): ):
model_name = model_name.replace("/", "-") # for callers using old naming with / in ViT names model_name = model_name.replace("/", "-") # for callers using old naming with / in ViT names
if isinstance(device, str): if isinstance(device, str):
...@@ -297,12 +300,7 @@ def create_model( ...@@ -297,12 +300,7 @@ def create_model(
pretrained_cfg = {} pretrained_cfg = {}
if pretrained: if pretrained:
checkpoint_path = "" checkpoint_path = fetch_or_download(pretrained_path)
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
elif os.path.exists(pretrained):
checkpoint_path = pretrained
if checkpoint_path: if checkpoint_path:
logging.info(f"Loading pretrained {model_name} weights ({pretrained}).") logging.info(f"Loading pretrained {model_name} weights ({pretrained}).")
...@@ -406,6 +404,7 @@ def create_model_and_transforms( ...@@ -406,6 +404,7 @@ def create_model_and_transforms(
image_std: Optional[Tuple[float, ...]] = None, image_std: Optional[Tuple[float, ...]] = None,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
skip_list: list = [], skip_list: list = [],
pretrained_path: str | PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
): ):
model = create_model( model = create_model(
model_name, model_name,
...@@ -423,6 +422,7 @@ def create_model_and_transforms( ...@@ -423,6 +422,7 @@ def create_model_and_transforms(
pretrained_text_model=pretrained_text_model, pretrained_text_model=pretrained_text_model,
cache_dir=cache_dir, cache_dir=cache_dir,
skip_list=skip_list, skip_list=skip_list,
pretrained_path=pretrained_path,
) )
image_mean = image_mean or getattr(model.visual, "image_mean", None) image_mean = image_mean or getattr(model.visual, "image_mean", None)
......
...@@ -24,6 +24,8 @@ def pulid_forward( ...@@ -24,6 +24,8 @@ def pulid_forward(
controlnet_single_block_samples=None, controlnet_single_block_samples=None,
return_dict: bool = True, return_dict: bool = True,
controlnet_blocks_repeat: bool = False, controlnet_blocks_repeat: bool = False,
start_timestep: float | None = None,
end_timestep: float | None = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]: ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
""" """
Copied from diffusers.models.flux.transformer_flux.py Copied from diffusers.models.flux.transformer_flux.py
...@@ -53,6 +55,16 @@ def pulid_forward( ...@@ -53,6 +55,16 @@ def pulid_forward(
""" """
hidden_states = self.x_embedder(hidden_states) hidden_states = self.x_embedder(hidden_states)
if timestep.numel() > 1:
timestep_float = timestep.flatten()[0].item()
else:
timestep_float = timestep.item()
if start_timestep is not None and start_timestep > timestep_float:
id_embeddings = None
if end_timestep is not None and end_timestep < timestep_float:
id_embeddings = None
timestep = timestep.to(hidden_states.dtype) * 1000 timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None: if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000 guidance = guidance.to(hidden_states.dtype) * 1000
......
...@@ -81,7 +81,7 @@ class NunchakuFluxTransformerBlocks(nn.Module): ...@@ -81,7 +81,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
self.id_weight = id_weight self.id_weight = id_weight
self.pulid_ca_idx = 0 self.pulid_ca_idx = 0
if self.id_embeddings is not None: if self.id_embeddings is not None:
self.set_residual_callback() self.set_pulid_residual_callback()
original_dtype = hidden_states.dtype original_dtype = hidden_states.dtype
original_device = hidden_states.device original_device = hidden_states.device
...@@ -129,7 +129,7 @@ class NunchakuFluxTransformerBlocks(nn.Module): ...@@ -129,7 +129,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
) )
if self.id_embeddings is not None: if self.id_embeddings is not None:
self.reset_residual_callback() self.reset_pulid_residual_callback()
hidden_states = hidden_states.to(original_dtype).to(original_device) hidden_states = hidden_states.to(original_dtype).to(original_device)
...@@ -194,7 +194,7 @@ class NunchakuFluxTransformerBlocks(nn.Module): ...@@ -194,7 +194,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
return encoder_hidden_states, hidden_states return encoder_hidden_states, hidden_states
def set_residual_callback(self): def set_pulid_residual_callback(self):
id_embeddings = self.id_embeddings id_embeddings = self.id_embeddings
pulid_ca = self.pulid_ca pulid_ca = self.pulid_ca
pulid_ca_idx = [self.pulid_ca_idx] pulid_ca_idx = [self.pulid_ca_idx]
...@@ -208,7 +208,7 @@ class NunchakuFluxTransformerBlocks(nn.Module): ...@@ -208,7 +208,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
self.callback_holder = callback self.callback_holder = callback
self.m.set_residual_callback(callback) self.m.set_residual_callback(callback)
def reset_residual_callback(self): def reset_pulid_residual_callback(self):
self.callback_holder = None self.callback_holder = None
self.m.set_residual_callback(None) self.m.set_residual_callback(None)
......
# Adapted from https://github.com/ToTheBeginning/PuLID/blob/main/pulid/pipeline.py # Adapted from https://github.com/ToTheBeginning/PuLID/blob/main/pulid/pipeline.py
import gc import gc
import logging
import os import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import cv2 import cv2
...@@ -14,9 +16,9 @@ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput ...@@ -14,9 +16,9 @@ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.utils import replace_example_docstring from diffusers.utils import replace_example_docstring
from facexlib.parsing import init_parsing_model from facexlib.parsing import init_parsing_model
from facexlib.utils.face_restoration_helper import FaceRestoreHelper from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from huggingface_hub import hf_hub_download, snapshot_download from huggingface_hub import snapshot_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from insightface.app import FaceAnalysis from insightface.app import FaceAnalysis
from safetensors.torch import load_file
from torch import nn from torch import nn
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import normalize, resize from torchvision.transforms.functional import normalize, resize
...@@ -25,16 +27,57 @@ from ..models.pulid.encoders_transformer import IDFormer, PerceiverAttentionCA ...@@ -25,16 +27,57 @@ from ..models.pulid.encoders_transformer import IDFormer, PerceiverAttentionCA
from ..models.pulid.eva_clip import create_model_and_transforms from ..models.pulid.eva_clip import create_model_and_transforms
from ..models.pulid.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from ..models.pulid.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from ..models.pulid.utils import img2tensor, resize_numpy_image_long, tensor2img from ..models.pulid.utils import img2tensor, resize_numpy_image_long, tensor2img
from ..models.transformers import NunchakuFluxTransformer2dModel
from ..utils import load_state_dict_in_safetensors, sha256sum
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
# Configure logging
logging.basicConfig(level=getattr(logging, log_level, logging.INFO), format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
def check_antelopev2_dir(antelopev2_dirpath: str | os.PathLike[str]) -> bool:
    """Verify a local antelopev2 model directory is complete and uncorrupted.

    Checks that ``antelopev2_dirpath`` exists and contains every required ONNX
    model file with the expected SHA-256 digest, so the caller can skip
    re-downloading the snapshot when a valid copy is already on disk.

    Args:
        antelopev2_dirpath: Path to the ``antelopev2`` model directory.

    Returns:
        True if the directory exists and every required file is present with a
        matching hash (or its hash check is explicitly skipped), else False.
    """
    antelopev2_dirpath = Path(antelopev2_dirpath)
    # Expected SHA-256 digests of the antelopev2 ONNX files; the sentinel
    # "<SKIP_HASH>" disables the integrity check for that entry.
    required_files = {
        "1k3d68.onnx": "df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc",
        "2d106det.onnx": "f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf",
        "genderage.onnx": "4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb",
        "glintr100.onnx": "4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf",
        "scrfd_10g_bnkps.onnx": "5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91",
    }
    if not antelopev2_dirpath.is_dir():
        logger.debug(f"Directory does not exist: {antelopev2_dirpath}")
        return False
    for filename, expected_hash in required_files.items():
        filepath = antelopev2_dirpath / filename
        if not filepath.exists():
            logger.debug(f"Missing file: {filepath}")
            return False
        # Hashing reads the whole file; only done once per startup when the
        # directory already exists, which is still cheaper than re-downloading.
        if expected_hash != "<SKIP_HASH>" and sha256sum(filepath) != expected_hash:
            logger.debug(f"Hash mismatch for: {filepath}")
            return False
    return True
class PuLIDPipeline(nn.Module): class PuLIDPipeline(nn.Module):
def __init__( def __init__(
self, dit, device, weight_dtype=torch.bfloat16, onnx_provider="gpu", folder_path="models", *args, **kwargs self,
dit: NunchakuFluxTransformer2dModel,
device: str | torch.device,
weight_dtype: str | torch.dtype = torch.bfloat16,
onnx_provider: str = "gpu",
pulid_path: str | os.PathLike[str] = "guozinan/PuLID/pulid_flux_v0.9.1.safetensors",
eva_clip_path: str | os.PathLike[str] = "QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt",
insightface_dirpath: str | os.PathLike[str] | None = None,
facexlib_dirpath: str | os.PathLike[str] | None = None,
): ):
super().__init__() super().__init__()
self.device = device self.device = device
self.weight_dtype = weight_dtype self.weight_dtype = weight_dtype
self.folder_path = folder_path
double_interval = 2 double_interval = 2
single_interval = 4 single_interval = 4
...@@ -54,6 +97,11 @@ class PuLIDPipeline(nn.Module): ...@@ -54,6 +97,11 @@ class PuLIDPipeline(nn.Module):
# preprocessors # preprocessors
# face align and parsing # face align and parsing
if facexlib_dirpath is None:
facexlib_dirpath = Path(HUGGINGFACE_HUB_CACHE) / "facexlib"
facexlib_dirpath = Path(facexlib_dirpath)
self.face_helper = FaceRestoreHelper( self.face_helper = FaceRestoreHelper(
upscale_factor=1, upscale_factor=1,
face_size=512, face_size=512,
...@@ -61,11 +109,17 @@ class PuLIDPipeline(nn.Module): ...@@ -61,11 +109,17 @@ class PuLIDPipeline(nn.Module):
det_model="retinaface_resnet50", det_model="retinaface_resnet50",
save_ext="png", save_ext="png",
device=self.device, device=self.device,
model_rootpath=str(facexlib_dirpath),
) )
self.face_helper.face_parse = None self.face_helper.face_parse = None
self.face_helper.face_parse = init_parsing_model(model_name="bisenet", device=self.device) self.face_helper.face_parse = init_parsing_model(
model_name="bisenet", device=self.device, model_rootpath=str(facexlib_dirpath)
)
# clip-vit backbone # clip-vit backbone
model, _, _ = create_model_and_transforms("EVA02-CLIP-L-14-336", "eva_clip", force_custom_clip=True) model, _, _ = create_model_and_transforms(
"EVA02-CLIP-L-14-336", "eva_clip", force_custom_clip=True, pretrained_path=eva_clip_path
)
model = model.visual model = model.visual
self.clip_vision_model = model.to(self.device, dtype=self.weight_dtype) self.clip_vision_model = model.to(self.device, dtype=self.weight_dtype)
eva_transform_mean = getattr(self.clip_vision_model, "image_mean", OPENAI_DATASET_MEAN) eva_transform_mean = getattr(self.clip_vision_model, "image_mean", OPENAI_DATASET_MEAN)
...@@ -76,46 +130,51 @@ class PuLIDPipeline(nn.Module): ...@@ -76,46 +130,51 @@ class PuLIDPipeline(nn.Module):
eva_transform_std = (eva_transform_std,) * 3 eva_transform_std = (eva_transform_std,) * 3
self.eva_transform_mean = eva_transform_mean self.eva_transform_mean = eva_transform_mean
self.eva_transform_std = eva_transform_std self.eva_transform_std = eva_transform_std
# antelopev2 # antelopev2
antelopev2_path = os.path.join(folder_path, "insightface", "models", "antelopev2") if insightface_dirpath is None:
snapshot_download("DIAMONIK7777/antelopev2", local_dir=antelopev2_path) insightface_dirpath = Path(HUGGINGFACE_HUB_CACHE) / "insightface"
insightface_dirpath = Path(insightface_dirpath)
if insightface_dirpath is not None:
antelopev2_dirpath = insightface_dirpath / "models" / "antelopev2"
else:
antelopev2_dirpath = None
if antelopev2_dirpath is None or not check_antelopev2_dir(antelopev2_dirpath):
snapshot_download("DIAMONIK7777/antelopev2", local_dir=antelopev2_dirpath)
providers = ( providers = (
["CPUExecutionProvider"] if onnx_provider == "cpu" else ["CUDAExecutionProvider", "CPUExecutionProvider"] ["CPUExecutionProvider"] if onnx_provider == "cpu" else ["CUDAExecutionProvider", "CPUExecutionProvider"]
) )
self.app = FaceAnalysis(name="antelopev2", root=os.path.join(folder_path, "insightface"), providers=providers) self.app = FaceAnalysis(name="antelopev2", root=insightface_dirpath, providers=providers)
self.app.prepare(ctx_id=0, det_size=(640, 640)) self.app.prepare(ctx_id=0, det_size=(640, 640))
self.handler_ante = insightface.model_zoo.get_model( self.handler_ante = insightface.model_zoo.get_model(
os.path.join(antelopev2_path, "glintr100.onnx"), providers=providers str(antelopev2_dirpath / "glintr100.onnx"), providers=providers
) )
self.handler_ante.prepare(ctx_id=0) self.handler_ante.prepare(ctx_id=0)
gc.collect() # pulid model
torch.cuda.empty_cache() state_dict = load_state_dict_in_safetensors(pulid_path)
module_state_dict = {}
# other configs
self.debug_img_list = []
def load_pretrain(self, pretrain_path=None, version="v0.9.1"):
hf_hub_download(
"guozinan/PuLID", f"pulid_flux_{version}.safetensors", local_dir=os.path.join(self.folder_path, "pulid")
)
ckpt_path = os.path.join(self.folder_path, "pulid", f"pulid_flux_{version}.safetensors")
if pretrain_path is not None:
ckpt_path = pretrain_path
state_dict = load_file(ckpt_path)
state_dict_dict = {}
for k, v in state_dict.items(): for k, v in state_dict.items():
module = k.split(".")[0] module = k.split(".")[0]
state_dict_dict.setdefault(module, {}) module_state_dict.setdefault(module, {})
new_k = k[len(module) + 1 :] new_k = k[len(module) + 1 :]
state_dict_dict[module][new_k] = v module_state_dict[module][new_k] = v
for module in state_dict_dict: for module in module_state_dict:
print(f"loading from {module}") logging.debug(f"loading from {module}")
getattr(self, module).load_state_dict(state_dict_dict[module], strict=True) getattr(self, module).load_state_dict(module_state_dict[module], strict=True)
del state_dict del state_dict
del state_dict_dict del module_state_dict
gc.collect()
torch.cuda.empty_cache()
# other configs
self.debug_img_list = []
def to_gray(self, img): def to_gray(self, img):
x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3] x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
...@@ -215,8 +274,6 @@ class PuLIDFluxPipeline(FluxPipeline): ...@@ -215,8 +274,6 @@ class PuLIDFluxPipeline(FluxPipeline):
pulid_device="cuda", pulid_device="cuda",
weight_dtype=torch.bfloat16, weight_dtype=torch.bfloat16,
onnx_provider="gpu", onnx_provider="gpu",
pretrained_model=None,
folder_path="models",
): ):
super().__init__( super().__init__(
scheduler=scheduler, scheduler=scheduler,
...@@ -241,9 +298,7 @@ class PuLIDFluxPipeline(FluxPipeline): ...@@ -241,9 +298,7 @@ class PuLIDFluxPipeline(FluxPipeline):
device=self.pulid_device, device=self.pulid_device,
weight_dtype=self.weight_dtype, weight_dtype=self.weight_dtype,
onnx_provider=self.onnx_provider, onnx_provider=self.onnx_provider,
folder_path=folder_path,
) )
self.pulid_model.load_pretrain(pretrained_model)
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING) @replace_example_docstring(EXAMPLE_DOC_STRING)
......
import hashlib
import os import os
import warnings import warnings
from pathlib import Path from pathlib import Path
...@@ -7,6 +8,14 @@ import torch ...@@ -7,6 +8,14 @@ import torch
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
def sha256sum(filepath: str | os.PathLike[str]) -> str:
sha256 = hashlib.sha256()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
sha256.update(chunk)
return sha256.hexdigest()
def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path: def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
path = Path(path) path = Path(path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment