Commit 235238bd authored by Hyunsung Lee, committed by Zhekai Zhang

Add controlnet

parent 63913f29
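# Example 1: FLUX.1-dev + ControlNet-Union-Pro (depth + canny) with the Nunchaku INT4 transformer and apply_cache_on_pipe enabled.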
import random
import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from nunchaku import NunchakuFluxTransformer2dModel
from diffusers.utils import load_image
import numpy as np
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union]) # we always recommend loading via FluxMultiControlNetModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev",
    torch_dtype=torch.bfloat16).to("cuda")
pipe = FluxControlNetPipeline.from_pretrained(
    base_model,
    transformer=transformer,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16)
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
pipe.to("cuda")
prompt = 'An anime-style girl with messy beach waves.'
control_image_depth = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg")
control_mode_depth = 2
control_image_canny = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg")
control_mode_canny = 0
width, height = control_image_depth.size
image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("nunchaku-controlnet-flux.1-dev.png")
import random
import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from nunchaku import NunchakuFluxTransformer2dModel
from diffusers.utils import load_image
import numpy as np
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union]) # we always recommend loading via FluxMultiControlNetModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev",
    torch_dtype=torch.bfloat16).to("cuda")
pipe = FluxControlNetPipeline.from_pretrained(
    base_model,
    transformer=transformer,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = 'An anime-style girl with messy beach waves.'
control_image_depth = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg")
control_mode_depth = 2
control_image_canny = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg")
control_mode_canny = 0
width, height = control_image_depth.size
image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("nunchaku-controlnet-flux.1-dev.png")
import random
import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from nunchaku import NunchakuFluxTransformer2dModel
from diffusers.utils import load_image
import numpy as np
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union]) # we always recommend loading via FluxMultiControlNetModel
pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = 'An anime-style girl with messy beach waves.'
control_image_depth = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg")
control_mode_depth = 2
control_image_canny = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg")
control_mode_canny = 0
width, height = control_image_depth.size
image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("reference-controlnet-flux.1-dev.png")
...@@ -35,6 +35,8 @@ public:
        torch::Tensor rotary_emb_img,
        torch::Tensor rotary_emb_context,
        torch::Tensor rotary_emb_single,
        std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
        std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
        bool skip_first_layer = false)
    {
        checkModel();
...@@ -56,6 +58,8 @@ public:
            from_torch(rotary_emb_img),
            from_torch(rotary_emb_context),
            from_torch(rotary_emb_single),
            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
            controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{},
            skip_first_layer
        );
...@@ -71,7 +75,9 @@ public:
        torch::Tensor encoder_hidden_states,
        torch::Tensor temb,
        torch::Tensor rotary_emb_img,
        torch::Tensor rotary_emb_context,
        std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
        std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt)
    {
        CUDADeviceContext ctx(deviceId);
...@@ -83,17 +89,19 @@ public:
        rotary_emb_img = rotary_emb_img.contiguous();
        rotary_emb_context = rotary_emb_context.contiguous();
        auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer(
            idx,
            from_torch(hidden_states),
            from_torch(encoder_hidden_states),
            from_torch(temb),
            from_torch(rotary_emb_img),
            from_torch(rotary_emb_context),
            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
            controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{}
        );
        hidden_states = to_torch(hidden_states_);
        encoder_hidden_states = to_torch(encoder_hidden_states_);
        Tensor::synchronizeDevice();
        return { hidden_states, encoder_hidden_states };
...
...@@ -26,8 +26,27 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            py::arg("dict"),
            py::arg("partial") = false
        )
        .def("forward", &QuantizedFluxModel::forward,
            py::arg("hidden_states"),
            py::arg("encoder_hidden_states"),
            py::arg("temb"),
            py::arg("rotary_emb_img"),
            py::arg("rotary_emb_context"),
            py::arg("rotary_emb_single"),
            py::arg("controlnet_block_samples") = py::none(),
            py::arg("controlnet_single_block_samples") = py::none(),
            py::arg("skip_first_layer") = false
        )
        .def("forward_layer", &QuantizedFluxModel::forward_layer,
            py::arg("idx"),
            py::arg("hidden_states"),
            py::arg("encoder_hidden_states"),
            py::arg("temb"),
            py::arg("rotary_emb_img"),
            py::arg("rotary_emb_context"),
            py::arg("controlnet_block_samples") = py::none(),
            py::arg("controlnet_single_block_samples") = py::none()
        )
        .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
        .def("startDebug", &QuantizedFluxModel::startDebug)
        .def("stopDebug", &QuantizedFluxModel::stopDebug)
...
from typing import Any, Dict, Optional, Union
import logging
import os
...@@ -5,6 +7,7 @@ import diffusers
import torch
from diffusers import FluxTransformer2DModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from huggingface_hub import utils
from packaging.version import Version
from safetensors.torch import load_file, save_file
...@@ -62,6 +65,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        encoder_hidden_states: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        joint_attention_kwargs=None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        skip_first_layer=False,
    ):
        batch_size = hidden_states.shape[0]
...@@ -76,6 +81,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        temb = temb.to(self.dtype).to(self.device)
        image_rotary_emb = image_rotary_emb.to(self.device)
        if controlnet_block_samples is not None:
            controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
        if controlnet_single_block_samples is not None:
            controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
        assert image_rotary_emb.ndim == 6
        assert image_rotary_emb.shape[0] == 1
        assert image_rotary_emb.shape[1] == 1
...@@ -89,7 +99,6 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
        rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
        rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
        hidden_states = self.m.forward(
            hidden_states,
            encoder_hidden_states,
...@@ -97,6 +106,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
            rotary_emb_img,
            rotary_emb_txt,
            rotary_emb_single,
            controlnet_block_samples,
            controlnet_single_block_samples,
            skip_first_layer,
        )
...@@ -115,6 +126,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        temb: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        joint_attention_kwargs=None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None
    ):
        batch_size = hidden_states.shape[0]
        txt_tokens = encoder_hidden_states.shape[1]
...@@ -128,6 +141,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        temb = temb.to(self.dtype).to(self.device)
        image_rotary_emb = image_rotary_emb.to(self.device)
        if controlnet_block_samples is not None:
            controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
        if controlnet_single_block_samples is not None:
            controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
        assert image_rotary_emb.ndim == 6
        assert image_rotary_emb.shape[0] == 1
        assert image_rotary_emb.shape[1] == 1
...@@ -141,7 +159,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
        hidden_states, encoder_hidden_states = self.m.forward_layer(
            idx, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt,
            controlnet_block_samples, controlnet_single_block_samples
        )
        hidden_states = hidden_states.to(original_dtype).to(original_device)
...@@ -473,3 +492,100 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
        state_dict.update(updated_vectors)
        self.transformer_blocks[0].m.loadDict(state_dict, True)
        self.reset_x_embedder()
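    # The forward pass below mirrors diffusers' FluxTransformer2DModel.forward, additionally passing the
    # ControlNet residuals (controlnet_block_samples / controlnet_single_block_samples) through to the quantized blocks.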
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        Copied from diffusers.models.flux.transformer_flux.py

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        else:
            guidance = None

        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

        nunchaku_block = self.transformer_blocks[0]
        encoder_hidden_states, hidden_states = nunchaku_block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
            joint_attention_kwargs=joint_attention_kwargs,
            controlnet_block_samples=controlnet_block_samples,
            controlnet_single_block_samples=controlnet_single_block_samples,
        )

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
...@@ -718,7 +718,16 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
    }
}

Tensor FluxModel::forward(
    Tensor hidden_states,
    Tensor encoder_hidden_states,
    Tensor temb,
    Tensor rotary_emb_img,
    Tensor rotary_emb_context,
    Tensor rotary_emb_single,
    Tensor controlnet_block_samples,
    Tensor controlnet_single_block_samples,
    bool skip_first_layer) {
    const int batch_size = hidden_states.shape[0];
    const Tensor::ScalarType dtype = hidden_states.dtype();
    const Device device = hidden_states.device();
...@@ -727,6 +736,8 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
    const int img_tokens = hidden_states.shape[1];
    const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();
    const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
    const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

    Tensor concat;
...@@ -735,6 +746,14 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);

            if (controlnet_block_samples.valid()) {
                int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
                int block_index = layer / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_block_samples;
                hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
            }
        } else {
            if (size_t(layer) == transformer_blocks.size()) {
                // txt first, same as diffusers
...@@ -745,10 +764,21 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
                }
                hidden_states = concat;
                encoder_hidden_states = {};
            }

            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            hidden_states = block->forward(hidden_states, temb, rotary_emb_single);

            if (controlnet_single_block_samples.valid()) {
                int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
                int block_index = (layer - transformer_blocks.size()) / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_single_block_samples
                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
        }
    };
    auto load = [&](int layer) {
...@@ -776,6 +806,50 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
    return hidden_states;
}
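// forward_layer runs a single transformer block and, when ControlNet residuals are given, adds the residual
// selected by the same interval_control indexing used in the full forward pass above.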
std::tuple<Tensor, Tensor> FluxModel::forward_layer(
    size_t layer,
    Tensor hidden_states,
    Tensor encoder_hidden_states,
    Tensor temb,
    Tensor rotary_emb_img,
    Tensor rotary_emb_context,
    Tensor controlnet_block_samples,
    Tensor controlnet_single_block_samples) {

    std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
        hidden_states,
        encoder_hidden_states,
        temb,
        rotary_emb_img,
        rotary_emb_context, 0.0f);

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
    const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

    if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
        int block_index = layer / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_block_samples;
        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
    } else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
        int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
        int block_index = (layer - transformer_blocks.size()) / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_single_block_samples
        auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
        slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
        hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
    }

    return { hidden_states, encoder_hidden_states };
}
void FluxModel::setAttentionImpl(AttentionImpl impl) {
    for (auto &&block : this->transformer_blocks) {
        block->attnImpl = impl;
...
...@@ -138,8 +138,25 @@ private:
class FluxModel : public Module {
public:
    FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
    Tensor forward(
        Tensor hidden_states,
        Tensor encoder_hidden_states,
        Tensor temb,
        Tensor rotary_emb_img,
        Tensor rotary_emb_context,
        Tensor rotary_emb_single,
        Tensor controlnet_block_samples,
        Tensor controlnet_single_block_samples,
        bool skip_first_layer = false);
    std::tuple<Tensor, Tensor> forward_layer(
        size_t layer,
        Tensor hidden_states,
        Tensor encoder_hidden_states,
        Tensor temb,
        Tensor rotary_emb_img,
        Tensor rotary_emb_context,
        Tensor controlnet_block_samples,
        Tensor controlnet_single_block_samples);

    void setAttentionImpl(AttentionImpl impl);

public:
...