Unverified Commit b737368d authored by K, committed by GitHub

feat: PuLID support (#274)



* add pulid

* Added support for mixing PuLID and non-PuLID generation in the same pipeline after PuLID has been loaded.

* Added the feature to load LoRA at any time.

* Organized the directory structure.

* Organized the code.

* Removed unused related code from eva-clip.

* style: apply Ruff formatting

* Refactored code and verified pulid works.

* add pulid tests

* auto detect precision in test

* Updated requirements.txt

* update requirements

* style: reformat the example

* style: reformat the example

* style: rename cb to call_back

* style: format the codes

* style: format the codes

* reformatted the code

* fix the repo forward

* clean up some dead code

* wrap up for pulid

---------
Co-authored-by: kkkxue <kkkxue@tencent.com>
Co-authored-by: muyangli <lmxyy1999@foxmail.com>
parent b4d3f50b
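For orientation, the pipeline added by this PR can be driven roughly as follows. This is a minimal sketch assembled from the test file added below; the prompt, step count, and guidance scale are illustrative, and the fallback comment reflects the mixed-use claim in the commit message rather than code shown in this diff:

```python
from types import MethodType

import torch
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.pulid.pulid_forward import pulid_forward
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
from nunchaku.utils import get_precision

# Load the quantized transformer ("int4" or "fp4", auto-detected from the GPU).
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = PuLIDFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Route the forward pass through the PuLID-aware implementation.
pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)

id_image = load_image("https://github.com/ToTheBeginning/PuLID/blob/main/example_inputs/liuyifei.png?raw=true")
image = pipeline(
    "A woman holding a sign that says hello world",
    id_image=id_image,  # identity reference; per the commit message, omitting it should fall back to plain generation
    id_weight=1,
    num_inference_steps=12,
    guidance_scale=3.5,
).images[0]
```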
# Adapted from https://github.com/ToTheBeginning/PuLID
import math

import cv2
import numpy as np
import torch
from torchvision.utils import make_grid


def resize_numpy_image_long(image, resize_long_edge=768):
    """Resize a numpy image so its long edge is at most ``resize_long_edge``, preserving aspect ratio."""
    h, w = image.shape[:2]
    if max(h, w) <= resize_long_edge:
        return image
    k = resize_long_edge / max(h, w)
    h = int(h * k)
    w = int(w * k)
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
    return image
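A quick usage sketch for the helper above (the input path is hypothetical; OpenCV loads images as BGR uint8 arrays):

```python
import cv2

from nunchaku.models.pulid.utils import resize_numpy_image_long

img = cv2.imread("face.png")  # hypothetical local file
img = resize_numpy_image_long(img, resize_long_edge=1024)  # long edge capped at 1024, aspect ratio kept
```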
# from basicsr
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If returned results only have
            one element, just return tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == "float64":
                img = img.astype("float32")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    return _totensor(imgs, bgr2rgb, float32)
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
            ndarray of shape (H x W). The channel order is BGR.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f"Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}")
        if out_type == np.uint8:
            # Unlike MATLAB, np.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1:
        result = result[0]
    return result
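A minimal round-trip sketch of the two basicsr helpers above (the random array stands in for a real image):

```python
import numpy as np

bgr = np.random.rand(64, 64, 3).astype(np.float32)  # fake BGR image in [0, 1]
t = img2tensor(bgr)   # -> float32 tensor of shape (3, 64, 64), RGB order
back = tensor2img(t)  # -> uint8 BGR ndarray of shape (64, 64, 3)
assert back.shape == (64, 64, 3) and back.dtype == np.uint8
```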
@@ -63,6 +63,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        temb: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        id_embeddings=None,
        id_weight=None,
        joint_attention_kwargs=None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
@@ -72,6 +74,12 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        txt_tokens = encoder_hidden_states.shape[1]
        img_tokens = hidden_states.shape[1]

        self.id_embeddings = id_embeddings
        self.id_weight = id_weight
        self.pulid_ca_idx = 0
        if self.id_embeddings is not None:
            self.set_residual_callback()

        original_dtype = hidden_states.dtype
        original_device = hidden_states.device
@@ -114,9 +122,13 @@ class NunchakuFluxTransformerBlocks(nn.Module):
            rotary_emb_single,
            controlnet_block_samples,
            controlnet_single_block_samples,
            skip_first_layer,
        )

        if self.id_embeddings is not None:
            self.reset_residual_callback()

        hidden_states = hidden_states.to(original_dtype).to(original_device)

        encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
@@ -179,7 +191,20 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        encoder_hidden_states = encoder_hidden_states.to(original_dtype).to(original_device)
        return encoder_hidden_states, hidden_states

    def set_residual_callback(self):
        id_embeddings = self.id_embeddings
        pulid_ca = self.pulid_ca
        pulid_ca_idx = [self.pulid_ca_idx]  # boxed in a list so the closure can advance it
        id_weight = self.id_weight

        def callback(hidden_states):
            # Apply the next PuLID cross-attention module to the identity embeddings
            # and return the weighted output as a residual for the C++ forward pass.
            ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states.to("cuda"))
            pulid_ca_idx[0] += 1
            return ip

        self.callback_holder = callback
        self.m.set_residual_callback(callback)

    def reset_residual_callback(self):
        self.callback_holder = None
        self.m.set_residual_callback(None)

    def __del__(self):
        self.m.reset()
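As a sanity check on the index bookkeeping: the closure must advance through exactly as many entries as `pulid_ca` holds per forward pass. Assuming the FLUX.1 block counts (19 joint and 38 single blocks — an assumption about the base model, not stated in this diff) and the cadence hard-coded in the C++ forward below, the arithmetic works out:

```python
n_joint, n_single = 19, 38  # assumed FLUX.1 block counts

joint_calls = sum(1 for i in range(n_joint) if i % 2 == 0)    # callback on every other joint block -> 10
single_calls = sum(1 for i in range(n_single) if i % 4 == 0)  # callback on every 4th single block -> 10

# 20 invocations per forward pass, which should line up with the length of pulid_ca.
assert joint_calls + single_calls == 20
```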
@@ -451,6 +476,11 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
        if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
            self._unquantized_part_loras = unquantized_part_loras

            # Exclude the PuLID cross-attention weights from the unquantized
            # state dict before recomputing the LoRA parameters.
            self._unquantized_part_sd = {k: v for k, v in self._unquantized_part_sd.items() if "pulid_ca" not in k}
            self._update_unquantized_part_lora_params(1)

        quantized_part_vectors = {}
from .pipeline_flux_pulid import PuLIDFluxPipeline

__all__ = ["PuLIDFluxPipeline"]
# pulid related
insightface
opencv-python
facexlib
onnxruntime
\ No newline at end of file
@@ -4,9 +4,10 @@
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "activation.h"

#include <nvtx3/nvToolsExt.h>
#include <pybind11/functional.h>
#include <iostream>

using spdlog::fmt_lib::format;
@@ -819,6 +820,13 @@ Tensor FluxModel::forward(
                hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
            }
            // The residual callback fires on every other joint block: hidden states are
            // copied to the CPU, the GIL is acquired, and the Python callable's output
            // is copied back to CUDA and added as a residual.
            if (residual_callback && layer % 2 == 0) {
                Tensor cpu_input = hidden_states.copy(Device::cpu());
                pybind11::gil_scoped_acquire gil;
                Tensor cpu_output = residual_callback(cpu_input);
                Tensor residual = cpu_output.copy(Device::cuda());
                hidden_states = kernels::add(hidden_states, residual);
            }
        } else {
            if (size_t(layer) == transformer_blocks.size()) {
                // txt first, same as diffusers
@@ -845,6 +853,17 @@ Tensor FluxModel::forward(
                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }

            // For single blocks, the callback fires on every fourth layer and only sees
            // the image tokens; its output is added back into that slice.
            size_t local_layer_idx = layer - transformer_blocks.size();
            if (residual_callback && local_layer_idx % 4 == 0) {
                Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                Tensor cpu_input = callback_input.copy(Device::cpu());
                pybind11::gil_scoped_acquire gil;
                Tensor cpu_output = residual_callback(cpu_input);
                Tensor residual = cpu_output.copy(Device::cuda());
                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice = kernels::add(slice, residual);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
        }
    };
@@ -936,3 +955,6 @@ void FluxModel::setAttentionImpl(AttentionImpl impl) {
        block->attnImpl = impl;
    }
}

void FluxModel::set_residual_callback(std::function<Tensor(const Tensor&)> cb) {
    residual_callback = std::move(cb);
}
\ No newline at end of file
@@ -5,6 +5,10 @@
#include "Module.h"
#include "Linear.h"
#include "layernorm.h"

#include <pybind11/functional.h>

namespace pybind11 {
class function;
}

enum class AttentionImpl {
    FlashAttention2 = 0,
@@ -160,12 +164,14 @@ public:
        Tensor controlnet_single_block_samples);

    void setAttentionImpl(AttentionImpl impl);
    void set_residual_callback(std::function<Tensor(const Tensor&)> cb);

public:
    const Tensor::ScalarType dtype;
    std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
    std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
    std::function<Tensor(const Tensor&)> residual_callback;

private:
    bool offload;
};
\ No newline at end of file
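From Python, the binding can be exercised directly. A toy sketch follows; the no-op callback is illustrative (in this PR the real callback is installed by `set_residual_callback` in the Python wrapper above), and `model` stands for the pybind11-wrapped FluxModel (`self.m`):

```python
import torch

def no_op_callback(hidden_states):
    # The callback receives the hidden states copied to CPU and must return
    # a tensor of the same shape, which is added back as a residual on CUDA.
    return torch.zeros_like(hidden_states)

model.set_residual_callback(no_op_callback)  # attach
# ... run forward passes ...
model.set_residual_callback(None)            # detach
```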
from types import MethodType

import numpy as np
import pytest
import torch
import torch.nn.functional as F
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.pulid.pulid_forward import pulid_forward
from nunchaku.models.pulid.utils import resize_numpy_image_long
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
from nunchaku.utils import get_precision, is_turing


@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_pulid():
    precision = get_precision()  # auto-detects whether the precision is 'int4' or 'fp4' based on your GPU
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
    pipeline = PuLIDFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)

    id_image = load_image("https://github.com/ToTheBeginning/PuLID/blob/main/example_inputs/liuyifei.png?raw=true")
    image = pipeline(
        "A woman holding a sign that says hello world",
        id_image=id_image,
        id_weight=1,
        num_inference_steps=12,
        guidance_scale=3.5,
    ).images[0]

    # Embed both the reference face and the generated image for comparison.
    id_image = id_image.convert("RGB")
    id_image_numpy = np.array(id_image)
    id_image = resize_numpy_image_long(id_image_numpy, 1024)
    id_embeddings, _ = pipeline.pulid_model.get_id_embedding(id_image)

    output_image = image.convert("RGB")
    output_image_numpy = np.array(output_image)
    output_image = resize_numpy_image_long(output_image_numpy, 1024)
    output_id_embeddings, _ = pipeline.pulid_model.get_id_embedding(output_image)

    # Identity preservation: mean cosine similarity across the 32 embedding tokens.
    cosine_similarities = (
        F.cosine_similarity(id_embeddings.view(32, 2048), output_id_embeddings.view(32, 2048), dim=1).mean().item()
    )
    print(cosine_similarities)
    assert cosine_similarities > 0.93
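To run just this test from Python (equivalent to `pytest -k test_flux_dev_pulid` on the command line):

```python
import pytest

pytest.main(["-k", "test_flux_dev_pulid", "-s"])
```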
@@ -5,4 +5,8 @@ torchmetrics
mediapipe
controlnet_aux
peft
git+https://github.com/asomoza/image_gen_aux.git
insightface
opencv-python
facexlib
onnxruntime
\ No newline at end of file