Unverified commit b737368d authored by K, committed by GitHub

feat: PuLID support (#274)



* add pulid

* Allow mixing PuLID and non-PuLID generation in the same pipeline after PuLID has been loaded.

* Added the feature to load LoRA at any time.

* Organized the directory structure.

* Organized the code.

* Removed unused related code from eva-clip.

* style: apply Ruff formatting

* Refactored code and verified pulid works.

* add pulid tests

* auto detect precision in test

* Updated requirements.txt

* update requirements

* style: reformat the example

* style: reformat the example

* style: rename cb to call_back

* style: format the codes

* style: format the codes

* reformatted the code

* fix the repo forward

* clean up some dead code

* wrap up for pulid

---------
Co-authored-by: kkkxue <kkkxue@tencent.com>
Co-authored-by: muyangli <lmxyy1999@foxmail.com>
parent b4d3f50b
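Taken together, the commits above add an identity-conditioned (PuLID) generation path to the Nunchaku FLUX pipeline. As a quick orientation before the file-by-file diff, here is a minimal usage sketch; it mirrors the test added at the end of this PR, so the model IDs and helpers below are the ones that test itself uses.
# Minimal PuLID usage sketch (mirrors the new test in this PR).
from types import MethodType
import torch
from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.pulid.pulid_forward import pulid_forward
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
from nunchaku.utils import get_precision

precision = get_precision()  # 'int4' or 'fp4', depending on the GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = PuLIDFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)  # route ID embeddings
id_image = load_image("https://github.com/ToTheBeginning/PuLID/blob/main/example_inputs/liuyifei.png?raw=true")
image = pipeline(
    "A woman holding a sign that says hello world",
    id_image=id_image,  # pass id_image=None to generate without PuLID conditioning
    id_weight=1,
    num_inference_steps=12,
    guidance_scale=3.5,
).images[0]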
# Adapted from https://github.com/ToTheBeginning/PuLID
import math
import cv2
import numpy as np
import torch
from torchvision.utils import make_grid
def resize_numpy_image_long(image, resize_long_edge=768):
h, w = image.shape[:2]
if max(h, w) <= resize_long_edge:
return image
k = resize_long_edge / max(h, w)
h = int(h * k)
w = int(w * k)
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
return image
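For reference, a quick check of the resize behavior of the helper above (the shapes are illustrative and assume the function is in scope):
import numpy as np
img = np.zeros((1000, 2000, 3), dtype=np.uint8)  # h=1000, w=2000
out = resize_numpy_image_long(img, resize_long_edge=768)
print(out.shape)  # (384, 768, 3): the long edge is scaled to 768 and the aspect ratio is preserved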
# from basicsr
def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""
def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == "float64":
img = img.astype("float32")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img
if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
return _totensor(imgs, bgr2rgb, float32)
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
"""Convert torch Tensors into image numpy arrays.
After clamping to [min, max], values will be normalized to [0, 1].
Args:
tensor (Tensor or list[Tensor]): Accept shapes:
1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
2) 3D Tensor of shape (3/1 x H x W);
3) 2D Tensor of shape (H x W).
Tensor channel should be in RGB order.
rgb2bgr (bool): Whether to change rgb to bgr.
out_type (numpy type): output types. If ``np.uint8``, transform outputs
to uint8 type with range [0, 255]; otherwise, float type with
range [0, 1]. Default: ``np.uint8``.
min_max (tuple[int]): min and max values for clamp.
Returns:
(ndarray or list): 3D ndarray of shape (H x W x C) or 2D ndarray of
shape (H x W). The channel order is BGR.
"""
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
if torch.is_tensor(tensor):
tensor = [tensor]
result = []
for _tensor in tensor:
_tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
_tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
n_dim = _tensor.dim()
if n_dim == 4:
img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
img_np = img_np.transpose(1, 2, 0)
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 3:
img_np = _tensor.numpy()
img_np = img_np.transpose(1, 2, 0)
if img_np.shape[2] == 1: # gray image
img_np = np.squeeze(img_np, axis=2)
else:
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 2:
img_np = _tensor.numpy()
else:
raise TypeError(f"Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}")
if out_type == np.uint8:
# Unlike MATLAB, numpy.uint8() WILL NOT round by default.
img_np = (img_np * 255.0).round()
img_np = img_np.astype(out_type)
result.append(img_np)
if len(result) == 1:
result = result[0]
return result
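A small round-trip sketch of the two helpers above (the array values are illustrative):
import numpy as np
import torch
bgr = (np.random.rand(64, 64, 3) * 255).astype(np.float64)  # OpenCV-style BGR array
t = img2tensor(bgr, bgr2rgb=True, float32=True)              # float32 tensor, shape (3, 64, 64), RGB order
img = tensor2img(t / 255.0, rgb2bgr=True, min_max=(0, 1))    # uint8 ndarray, shape (64, 64, 3), BGR order
assert img.shape == (64, 64, 3) and img.dtype == np.uint8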
@@ -63,6 +63,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
temb: torch.Tensor,
encoder_hidden_states: torch.Tensor,
image_rotary_emb: torch.Tensor,
id_embeddings=None,
id_weight=None,
joint_attention_kwargs=None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
@@ -72,6 +74,12 @@ class NunchakuFluxTransformerBlocks(nn.Module):
txt_tokens = encoder_hidden_states.shape[1]
img_tokens = hidden_states.shape[1]
self.id_embeddings = id_embeddings
self.id_weight = id_weight
self.pulid_ca_idx = 0
if self.id_embeddings is not None:
self.set_residual_callback()
original_dtype = hidden_states.dtype
original_device = hidden_states.device
@@ -114,9 +122,13 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_single,
controlnet_block_samples,
controlnet_single_block_samples,
skip_first_layer,
)
if self.id_embeddings is not None:
self.reset_residual_callback()
hidden_states = hidden_states.to(original_dtype).to(original_device)
encoder_hidden_states = hidden_states[:, :txt_tokens, ...]
@@ -179,7 +191,20 @@ class NunchakuFluxTransformerBlocks(nn.Module):
encoder_hidden_states = encoder_hidden_states.to(original_dtype).to(original_device)
return encoder_hidden_states, hidden_states
def set_residual_callback(self):
id_embeddings = self.id_embeddings
pulid_ca = self.pulid_ca
pulid_ca_idx = [self.pulid_ca_idx]
id_weight = self.id_weight
def callback(hidden_states):
ip = id_weight * pulid_ca[pulid_ca_idx[0]](id_embeddings, hidden_states.to("cuda"))
pulid_ca_idx[0] += 1
return ip
self.callback_holder = callback
self.m.set_residual_callback(callback)
def reset_residual_callback(self):
self.callback_holder = None
self.m.set_residual_callback(None)
def __del__(self):
self.m.reset()
@@ -451,6 +476,11 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
self._unquantized_part_loras = unquantized_part_loras
self._unquantized_part_sd = {
k: v for k, v in self._unquantized_part_sd.items() if "pulid_ca" not in k
}
self._update_unquantized_part_lora_params(1)
quantized_part_vectors = {}
...
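The key mechanism in this hunk is the residual callback: the Python side registers a closure on the quantized C++ model (`self.m`), and the C++ forward pass (see the `FluxModel::forward` changes later in this diff) copies the hidden states to the CPU, calls the closure under the GIL at every 2nd joint block and every 4th single block, and adds the returned residual back on the GPU. Below is a stand-alone sketch of such a closure, mirroring `set_residual_callback` above; the `pulid_ca` module list and the ID embeddings are assumed to already exist.
import torch

def make_pulid_callback(pulid_ca, id_embeddings, id_weight=1.0):
    """Build a residual callback that injects PuLID cross-attention output."""
    idx = [0]  # mutable counter: one PerceiverAttentionCA module per callback site

    def callback(hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states arrives as a CPU tensor copied out of the C++ forward pass
        residual = id_weight * pulid_ca[idx[0]](id_embeddings, hidden_states.to("cuda"))
        idx[0] += 1
        return residual

    return callback

# hypothetical wiring: model.set_residual_callback(make_pulid_callback(pulid_ca, id_embeddings))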
from .pipeline_flux_pulid import PuLIDFluxPipeline
__all__ = ["PuLIDFluxPipeline"]
# Adapted from https://github.com/ToTheBeginning/PuLID/blob/main/pulid/pipeline.py
import gc
from typing import Any, Callable, Dict, List, Optional, Union
import cv2
import insightface
import numpy as np
import torch
from diffusers import FluxPipeline
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, EXAMPLE_DOC_STRING, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.utils import (
replace_example_docstring,
)
from facexlib.parsing import init_parsing_model
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from huggingface_hub import hf_hub_download, snapshot_download
from insightface.app import FaceAnalysis
from safetensors.torch import load_file
from torch import nn
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import normalize, resize
from ..models.pulid.encoders_transformer import IDFormer, PerceiverAttentionCA
from ..models.pulid.eva_clip import create_model_and_transforms
from ..models.pulid.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from ..models.pulid.utils import img2tensor, resize_numpy_image_long, tensor2img
class PuLIDPipeline(nn.Module):
def __init__(self, dit, device, weight_dtype=torch.bfloat16, onnx_provider="gpu", *args, **kwargs):
super().__init__()
self.device = device
self.weight_dtype = weight_dtype
double_interval = 2
single_interval = 4
# init encoder
self.pulid_encoder = IDFormer().to(self.device, self.weight_dtype)
num_ca = 19 // double_interval + 38 // single_interval
if 19 % double_interval != 0:
num_ca += 1
if 38 % single_interval != 0:
num_ca += 1
self.pulid_ca = nn.ModuleList(
[PerceiverAttentionCA().to(self.device, self.weight_dtype) for _ in range(num_ca)]
)
dit.transformer_blocks[0].pulid_ca = self.pulid_ca
# preprocessors
# face align and parsing
self.face_helper = FaceRestoreHelper(
upscale_factor=1,
face_size=512,
crop_ratio=(1, 1),
det_model="retinaface_resnet50",
save_ext="png",
device=self.device,
)
self.face_helper.face_parse = None
self.face_helper.face_parse = init_parsing_model(model_name="bisenet", device=self.device)
# clip-vit backbone
model, _, _ = create_model_and_transforms("EVA02-CLIP-L-14-336", "eva_clip", force_custom_clip=True)
model = model.visual
self.clip_vision_model = model.to(self.device, dtype=self.weight_dtype)
eva_transform_mean = getattr(self.clip_vision_model, "image_mean", OPENAI_DATASET_MEAN)
eva_transform_std = getattr(self.clip_vision_model, "image_std", OPENAI_DATASET_STD)
if not isinstance(eva_transform_mean, (list, tuple)):
eva_transform_mean = (eva_transform_mean,) * 3
if not isinstance(eva_transform_std, (list, tuple)):
eva_transform_std = (eva_transform_std,) * 3
self.eva_transform_mean = eva_transform_mean
self.eva_transform_std = eva_transform_std
# antelopev2
snapshot_download("DIAMONIK7777/antelopev2", local_dir="models/antelopev2")
providers = (
["CPUExecutionProvider"] if onnx_provider == "cpu" else ["CUDAExecutionProvider", "CPUExecutionProvider"]
)
self.app = FaceAnalysis(name="antelopev2", root=".", providers=providers)
self.app.prepare(ctx_id=0, det_size=(640, 640))
self.handler_ante = insightface.model_zoo.get_model("models/antelopev2/glintr100.onnx", providers=providers)
self.handler_ante.prepare(ctx_id=0)
gc.collect()
torch.cuda.empty_cache()
# other configs
self.debug_img_list = []
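The `num_ca` arithmetic above works out to 20 modules: with `double_interval=2` over the 19 joint blocks and `single_interval=4` over the 38 single blocks there are 10 + 10 callback sites, matching the `layer % 2 == 0` and `local_layer_idx % 4 == 0` conditions in the C++ forward pass further down in this diff.
double_interval, single_interval = 2, 4
num_ca = 19 // double_interval + 38 // single_interval               # 9 + 9
num_ca += (19 % double_interval != 0) + (38 % single_interval != 0)  # +1 +1 for the remainders
assert num_ca == 20  # one PerceiverAttentionCA per residual-callback site (10 joint + 10 single)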
def load_pretrain(self, pretrain_path=None, version="v0.9.0"):
hf_hub_download("guozinan/PuLID", f"pulid_flux_{version}.safetensors", local_dir="models")
ckpt_path = f"models/pulid_flux_{version}.safetensors"
if pretrain_path is not None:
ckpt_path = pretrain_path
state_dict = load_file(ckpt_path)
state_dict_dict = {}
for k, v in state_dict.items():
module = k.split(".")[0]
state_dict_dict.setdefault(module, {})
new_k = k[len(module) + 1 :]
state_dict_dict[module][new_k] = v
for module in state_dict_dict:
print(f"loading from {module}")
getattr(self, module).load_state_dict(state_dict_dict[module], strict=True)
del state_dict
del state_dict_dict
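To illustrate the key-splitting loop in `load_pretrain` (the key names below are made up for the example), each flat checkpoint key is routed to the attribute named by its first dotted component, with that prefix stripped:
import torch
state_dict = {"pulid_encoder.proj.weight": torch.zeros(1), "pulid_ca.0.to_q.weight": torch.zeros(1)}
state_dict_dict = {}
for k, v in state_dict.items():
    module = k.split(".")[0]
    state_dict_dict.setdefault(module, {})[k[len(module) + 1 :]] = v
# {'pulid_encoder': {'proj.weight': ...}, 'pulid_ca': {'0.to_q.weight': ...}}
# Each sub-dict is then loaded into getattr(self, module) with strict=True.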
def to_gray(self, img):
x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
x = x.repeat(1, 3, 1, 1)
return x
@torch.no_grad()
def get_id_embedding(self, image, cal_uncond=False):
"""
Args:
image: numpy rgb image, range [0, 255]
"""
self.face_helper.clean_all()
self.debug_img_list = []
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
# get antelopev2 embedding
face_info = self.app.get(image_bgr)
if len(face_info) > 0:
face_info = sorted(face_info, key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]))[
-1
] # only use the maximum face
id_ante_embedding = face_info["embedding"]
self.debug_img_list.append(
image[
int(face_info["bbox"][1]) : int(face_info["bbox"][3]),
int(face_info["bbox"][0]) : int(face_info["bbox"][2]),
]
)
else:
id_ante_embedding = None
# using facexlib to detect and align face
self.face_helper.read_image(image_bgr)
self.face_helper.get_face_landmarks_5(only_center_face=True)
self.face_helper.align_warp_face()
if len(self.face_helper.cropped_faces) == 0:
raise RuntimeError("facexlib align face fail")
align_face = self.face_helper.cropped_faces[0]
# in case insightface didn't detect a face
if id_ante_embedding is None:
print("failed to detect a face with insightface; extracting the embedding from the aligned face instead")
id_ante_embedding = self.handler_ante.get_feat(align_face)
id_ante_embedding = torch.from_numpy(id_ante_embedding).to(self.device, self.weight_dtype)
if id_ante_embedding.ndim == 1:
id_ante_embedding = id_ante_embedding.unsqueeze(0)
# parsing
input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0
input = input.to(self.device)
parsing_out = self.face_helper.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
parsing_out = parsing_out.argmax(dim=1, keepdim=True)
bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
bg = sum(parsing_out == i for i in bg_label).bool()
white_image = torch.ones_like(input)
# only keep the face features
face_features_image = torch.where(bg, white_image, self.to_gray(input))
self.debug_img_list.append(tensor2img(face_features_image, rgb2bgr=False))
# transform img before sending to eva-clip-vit
face_features_image = resize(face_features_image, self.clip_vision_model.image_size, InterpolationMode.BICUBIC)
face_features_image = normalize(face_features_image, self.eva_transform_mean, self.eva_transform_std)
id_cond_vit, id_vit_hidden = self.clip_vision_model(
face_features_image.to(self.weight_dtype), return_all_features=False, return_hidden=True, shuffle=False
)
id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True)
id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm)
id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1)
id_embedding = self.pulid_encoder(id_cond, id_vit_hidden)
if not cal_uncond:
return id_embedding, None
id_uncond = torch.zeros_like(id_cond)
id_vit_hidden_uncond = []
for layer_idx in range(0, len(id_vit_hidden)):
id_vit_hidden_uncond.append(torch.zeros_like(id_vit_hidden[layer_idx]))
uncond_id_embedding = self.pulid_encoder(id_uncond, id_vit_hidden_uncond)
return id_embedding, uncond_id_embedding
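A minimal sketch of calling `get_id_embedding` on its own (assuming `pulid` is an already constructed `PuLIDPipeline` and `face.png` is a hypothetical input file):
import numpy as np
from PIL import Image
face = np.array(Image.open("face.png").convert("RGB"))  # RGB uint8 array, values in [0, 255]
face = resize_numpy_image_long(face, 1024)               # cap the long edge, as the Flux pipeline does
id_embedding, uncond_id_embedding = pulid.get_id_embedding(face, cal_uncond=False)
# uncond_id_embedding is None unless cal_uncond=True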
class PuLIDFluxPipeline(FluxPipeline):
def __init__(
self,
scheduler,
vae,
text_encoder,
tokenizer,
text_encoder_2,
tokenizer_2,
transformer,
image_encoder=None,
feature_extractor=None,
pulid_device="cuda",
weight_dtype=torch.bfloat16,
onnx_provider="gpu",
pretrained_model=None,
):
super().__init__(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
transformer=transformer,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
# Save custom parameters
self.pulid_device = torch.device(pulid_device)
self.weight_dtype = weight_dtype
self.onnx_provider = onnx_provider
# Init PuLID pipeline (injects ID encoder into transformer)
self.pulid_model = PuLIDPipeline(
dit=self.transformer, # directly mutate transformer with pulid_ca
device=self.pulid_device,
weight_dtype=self.weight_dtype,
onnx_provider=self.onnx_provider,
)
self.pulid_model.load_pretrain(pretrained_model)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
true_cfg_scale: float = 1.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
id_image=None,
id_weight=1.0,
start_step=0,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
be used instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When > 1.0 and a `negative_prompt` is provided, enables true classifier-free guidance.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 28):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
negative_ip_adapter_image:
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
if id_image is not None:
# pil_image = Image.open(id_image)
pil_image = id_image.convert("RGB")
numpy_image = np.array(pil_image)
id_image = resize_numpy_image_long(numpy_image, 1024)
id_embeddings, uncond_id_embeddings = self.pulid_model.get_id_embedding(id_image)
else:
id_embeddings = None
uncond_id_embeddings = None
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_true_cfg:
(
negative_prompt_embeds,
negative_pooled_prompt_embeds,
_,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
):
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
):
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
image_embeds = None
negative_image_embeds = None
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
negative_ip_adapter_image,
negative_ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
if image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latents,
id_embeddings=id_embeddings,
id_weight=id_weight,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
neg_noise_pred = self.transformer(
hidden_states=latents,
id_embeddings=id_embeddings,
id_weight=id_weight,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
# pulid related
insightface
opencv-python
facexlib
onnxruntime
\ No newline at end of file
@@ -4,9 +4,10 @@
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "activation.h"
#include <nvtx3/nvToolsExt.h>
#include <pybind11/functional.h>
#include <iostream>
using spdlog::fmt_lib::format;
@@ -819,6 +820,13 @@ Tensor FluxModel::forward(
hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
}
if (residual_callback && layer % 2 == 0) {
Tensor cpu_input = hidden_states.copy(Device::cpu());
pybind11::gil_scoped_acquire gil;
Tensor cpu_output = residual_callback(cpu_input);
Tensor residual = cpu_output.copy(Device::cuda());
hidden_states = kernels::add(hidden_states, residual);
}
} else {
if (size_t(layer) == transformer_blocks.size()) {
// txt first, same as diffusers
@@ -845,6 +853,17 @@ Tensor FluxModel::forward(
auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
}
size_t local_layer_idx = layer - transformer_blocks.size();
if (residual_callback && local_layer_idx % 4 == 0) {
Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
Tensor cpu_input = callback_input.copy(Device::cpu());
pybind11::gil_scoped_acquire gil;
Tensor cpu_output = residual_callback(cpu_input);
Tensor residual = cpu_output.copy(Device::cuda());
auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
slice = kernels::add(slice, residual);
hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
}
}
};
@@ -936,3 +955,6 @@ void FluxModel::setAttentionImpl(AttentionImpl impl) {
block->attnImpl = impl;
}
}
void FluxModel::set_residual_callback(std::function<Tensor(const Tensor&)> cb) {
residual_callback = std::move(cb);
}
\ No newline at end of file
@@ -5,6 +5,10 @@
#include "Module.h"
#include "Linear.h"
#include "layernorm.h"
#include <pybind11/functional.h>
namespace pybind11 {
class function;
}
enum class AttentionImpl {
FlashAttention2 = 0,
@@ -160,12 +164,14 @@ public:
Tensor controlnet_single_block_samples);
void setAttentionImpl(AttentionImpl impl);
void set_residual_callback(std::function<Tensor(const Tensor&)> cb);
public:
const Tensor::ScalarType dtype;
std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
std::function<Tensor(const Tensor&)> residual_callback;
private:
bool offload;
};
\ No newline at end of file
from types import MethodType
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.pulid.pulid_forward import pulid_forward
from nunchaku.models.pulid.utils import resize_numpy_image_long
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_pulid():
precision = get_precision()  # auto-detect whether the precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"mit-han-lab/svdq-{precision}-flux.1-dev")
pipeline = PuLIDFluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
).to("cuda")
pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)
id_image = load_image("https://github.com/ToTheBeginning/PuLID/blob/main/example_inputs/liuyifei.png?raw=true")
image = pipeline(
"A woman holding a sign that says hello world",
id_image=id_image,
id_weight=1,
num_inference_steps=12,
guidance_scale=3.5,
).images[0]
id_image = id_image.convert("RGB")
id_image_numpy = np.array(id_image)
id_image = resize_numpy_image_long(id_image_numpy, 1024)
id_embeddings, _ = pipeline.pulid_model.get_id_embedding(id_image)
output_image = image.convert("RGB")
output_image_numpy = np.array(output_image)
output_image = resize_numpy_image_long(output_image_numpy, 1024)
output_id_embeddings, _ = pipeline.pulid_model.get_id_embedding(output_image)
cosine_similarities = (
F.cosine_similarity(id_embeddings.view(32, 2048), output_id_embeddings.view(32, 2048), dim=1).mean().item()
)
print(cosine_similarities)
assert cosine_similarities > 0.93
@@ -5,4 +5,8 @@
torchmetrics
mediapipe
controlnet_aux
peft
git+https://github.com/asomoza/image_gen_aux.git
insightface
opencv-python
facexlib
onnxruntime
\ No newline at end of file