Unverified Commit 46f4251a authored by K's avatar K Committed by GitHub
Browse files

fix: upgrade pulid to 0.9.1 (#408)

* upgrade pulid to 0.9.1. When the EVA CLIP config is not found, the built-in JSON parameters are used instead. The issue of repeated model downloads has been resolved, and the model path specified by ComfyUI is now faithfully used.

style: apply black check

style: apply isort check

* fix callback bugs

style: apply clang-format
parent e302ec32
......@@ -42,7 +42,6 @@ public:
if (net) {
pybind11::object cb = residual_callback;
net->set_residual_callback([cb](const Tensor &x) -> Tensor {
pybind11::gil_scoped_acquire gil;
torch::Tensor torch_x = to_torch(x);
pybind11::object result = cb(torch_x);
torch::Tensor torch_y = result.cast<torch::Tensor>();
......
......@@ -239,8 +239,35 @@ def create_model(
if model_cfg is not None:
logging.info(f"Loaded {model_name} model config.")
else:
logging.error(f"Model config for {model_name} not found; available models {list_models()}.")
raise RuntimeError(f"Model config for {model_name} not found.")
model_cfg = {
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"layers": 24,
"width": 1024,
"drop_path_rate": 0,
"head_width": 64,
"mlp_ratio": 2.6667,
"patch_size": 14,
"eva_model_name": "eva-clip-l-14-336",
"xattn": True,
"fusedLN": True,
"rope": True,
"pt_hw_seq_len": 16,
"intp_freq": True,
"naiveswiglu": True,
"subln": True,
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"xattn": False,
"fusedLN": True,
},
}
if "rope" in model_cfg.get("vision_cfg", {}):
if model_cfg["vision_cfg"]["rope"]:
......
......@@ -201,7 +201,7 @@ class NunchakuFluxTransformerBlocks(nn.Module):
id_weight = self.id_weight
def callback(hidden_states):
    # Residual callback invoked from the transformer forward pass: apply the
    # PuLID cross-attention adapter selected by the shared mutable index
    # pulid_ca_idx, then advance the index so the next invocation uses the
    # next adapter in sequence.
    # hidden_states is passed through on its current device — no explicit
    # .to("cuda") transfer (the pre-0.9.1 duplicate line doing that has been
    # removed; it recomputed the same adapter output only to discard it).
    ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states)
    pulid_ca_idx[0] += 1
    return ip
......
# Adapted from https://github.com/ToTheBeginning/PuLID/blob/main/pulid/pipeline.py
import gc
import os
from typing import Any, Callable, Dict, List, Optional, Union
import cv2
......@@ -27,10 +28,13 @@ from ..models.pulid.utils import img2tensor, resize_numpy_image_long, tensor2img
class PuLIDPipeline(nn.Module):
def __init__(self, dit, device, weight_dtype=torch.bfloat16, onnx_provider="gpu", *args, **kwargs):
def __init__(
self, dit, device, weight_dtype=torch.bfloat16, onnx_provider="gpu", folder_path="models", *args, **kwargs
):
super().__init__()
self.device = device
self.weight_dtype = weight_dtype
self.folder_path = folder_path
double_interval = 2
single_interval = 4
......@@ -73,13 +77,16 @@ class PuLIDPipeline(nn.Module):
self.eva_transform_mean = eva_transform_mean
self.eva_transform_std = eva_transform_std
# antelopev2
snapshot_download("DIAMONIK7777/antelopev2", local_dir="models/antelopev2")
antelopev2_path = os.path.join(folder_path, "insightface", "models", "antelopev2")
snapshot_download("DIAMONIK7777/antelopev2", local_dir=antelopev2_path)
providers = (
["CPUExecutionProvider"] if onnx_provider == "cpu" else ["CUDAExecutionProvider", "CPUExecutionProvider"]
)
self.app = FaceAnalysis(name="antelopev2", root=".", providers=providers)
self.app = FaceAnalysis(name="antelopev2", root=os.path.join(folder_path, "insightface"), providers=providers)
self.app.prepare(ctx_id=0, det_size=(640, 640))
self.handler_ante = insightface.model_zoo.get_model("models/antelopev2/glintr100.onnx", providers=providers)
self.handler_ante = insightface.model_zoo.get_model(
os.path.join(antelopev2_path, "glintr100.onnx"), providers=providers
)
self.handler_ante.prepare(ctx_id=0)
gc.collect()
......@@ -88,9 +95,11 @@ class PuLIDPipeline(nn.Module):
# other configs
self.debug_img_list = []
def load_pretrain(self, pretrain_path=None, version="v0.9.0"):
hf_hub_download("guozinan/PuLID", f"pulid_flux_{version}.safetensors", local_dir="models")
ckpt_path = f"models/pulid_flux_{version}.safetensors"
def load_pretrain(self, pretrain_path=None, version="v0.9.1"):
hf_hub_download(
"guozinan/PuLID", f"pulid_flux_{version}.safetensors", local_dir=os.path.join(self.folder_path, "pulid")
)
ckpt_path = os.path.join(self.folder_path, "pulid", f"pulid_flux_{version}.safetensors")
if pretrain_path is not None:
ckpt_path = pretrain_path
state_dict = load_file(ckpt_path)
......@@ -207,6 +216,7 @@ class PuLIDFluxPipeline(FluxPipeline):
weight_dtype=torch.bfloat16,
onnx_provider="gpu",
pretrained_model=None,
folder_path="models",
):
super().__init__(
scheduler=scheduler,
......@@ -231,6 +241,7 @@ class PuLIDFluxPipeline(FluxPipeline):
device=self.pulid_device,
weight_dtype=self.weight_dtype,
onnx_provider=self.onnx_provider,
folder_path=folder_path,
)
self.pulid_model.load_pretrain(pretrained_model)
......
......@@ -837,11 +837,8 @@ Tensor FluxModel::forward(Tensor hidden_states,
hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
}
if (residual_callback && layer % 2 == 0) {
Tensor cpu_input = hidden_states.copy(Device::cpu());
pybind11::gil_scoped_acquire gil;
Tensor cpu_output = residual_callback(cpu_input);
Tensor residual = cpu_output.copy(Device::cuda());
hidden_states = kernels::add(hidden_states, residual);
Tensor residual = residual_callback(hidden_states);
hidden_states = kernels::add(hidden_states, residual);
}
} else {
if (size_t(layer) == transformer_blocks.size()) {
......@@ -875,12 +872,9 @@ Tensor FluxModel::forward(Tensor hidden_states,
size_t local_layer_idx = layer - transformer_blocks.size();
if (residual_callback && local_layer_idx % 4 == 0) {
Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
Tensor cpu_input = callback_input.copy(Device::cpu());
pybind11::gil_scoped_acquire gil;
Tensor cpu_output = residual_callback(cpu_input);
Tensor residual = cpu_output.copy(Device::cuda());
auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
slice = kernels::add(slice, residual);
Tensor residual = residual_callback(callback_input);
auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
slice = kernels::add(slice, residual);
hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment