Commit 235238bd authored by Hyunsung Lee, committed by Zhekai Zhang

Add controlnet

parent 63913f29
import random
import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from nunchaku import NunchakuFluxTransformer2dModel
from diffusers.utils import load_image
import numpy as np
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union]) # we always recommend loading via FluxMultiControlNetModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe = FluxControlNetPipeline.from_pretrained(
    base_model,
    transformer=transformer,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
)
# enable nunchaku's residual caching; a larger threshold reuses more computation at some quality cost
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
pipe.to("cuda")
prompt = 'An anime style girl with messy beach waves.'
control_image_depth = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg")
control_mode_depth = 2  # Union-Pro control mode 2 selects depth conditioning
control_image_canny = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg")
control_mode_canny = 0  # Union-Pro control mode 0 selects canny conditioning
width, height = control_image_depth.size
image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("nunchaku-controlnet-flux.1-dev.png")
import random
import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from nunchaku import NunchakuFluxTransformer2dModel
from diffusers.utils import load_image
import numpy as np
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union]) # we always recommend loading via FluxMultiControlNetModel
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe = FluxControlNetPipeline.from_pretrained(
    base_model,
    transformer=transformer,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
prompt = 'An anime style girl with messy beach waves.'
control_image_depth = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg")
control_mode_depth = 2
control_image_canny = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg")
control_mode_canny = 0
width, height = control_image_depth.size
image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("nunchaku-controlnet-flux.1-dev.png")
import random
import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from nunchaku import NunchakuFluxTransformer2dModel
from diffusers.utils import load_image
import numpy as np
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union]) # we always recommend loading via FluxMultiControlNetModel
pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = 'An anime style girl with messy beach waves.'
control_image_depth = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg")
control_mode_depth = 2
control_image_canny = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg")
control_mode_canny = 0
width, height = control_image_depth.size
image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("reference-controlnet-flux.1-dev.png")
@@ -35,6 +35,8 @@ public:
        torch::Tensor rotary_emb_img,
        torch::Tensor rotary_emb_context,
        torch::Tensor rotary_emb_single,
        std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
        std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
        bool skip_first_layer = false)
    {
        checkModel();
@@ -56,6 +58,8 @@ public:
            from_torch(rotary_emb_img),
            from_torch(rotary_emb_context),
            from_torch(rotary_emb_single),
            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
            controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{},
            skip_first_layer
        );
@@ -71,7 +75,9 @@ public:
        torch::Tensor encoder_hidden_states,
        torch::Tensor temb,
        torch::Tensor rotary_emb_img,
        torch::Tensor rotary_emb_context,
        std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
        std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt)
    {
        CUDADeviceContext ctx(deviceId);
@@ -83,17 +89,19 @@ public:
        rotary_emb_img = rotary_emb_img.contiguous();
        rotary_emb_context = rotary_emb_context.contiguous();
        auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer(
            idx,
            from_torch(hidden_states),
            from_torch(encoder_hidden_states),
            from_torch(temb),
            from_torch(rotary_emb_img),
            from_torch(rotary_emb_context),
            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
            controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{}
        );
        hidden_states = to_torch(hidden_states_);
        encoder_hidden_states = to_torch(encoder_hidden_states_);
        Tensor::synchronizeDevice();
        return { hidden_states, encoder_hidden_states };
...
@@ -10,7 +10,7 @@ class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
public:
    void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) {
        spdlog::info("Initializing QuantizedGEMM");
        size_t val = 0;
        checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
        checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
@@ -27,7 +27,7 @@ public:
        x = x.contiguous();
        Tensor result = net->forward(from_torch(x));
        torch::Tensor output = to_torch(result);
        Tensor::synchronizeDevice();
@@ -48,7 +48,7 @@ public:
        const int M = x.shape[0];
        const int K = x.shape[1] * 2;
        assert(x.dtype() == Tensor::INT8);
        // activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
@@ -83,7 +83,7 @@ public:
                }
            }
        }
        ss << std::endl;
        return ss.str();
    }
@@ -99,7 +99,7 @@ public:
            from_torch(x),
            fuse_glu
        );
        Tensor act = qout.act.copy(Device::cpu());
        Tensor ascales = qout.ascales.copy(Device::cpu());
        Tensor lora_act = qout.lora_act.copy(Device::cpu());
...
@@ -18,16 +18,35 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("deviceId")
    )
    .def("reset", &QuantizedFluxModel::reset)
    .def("load", &QuantizedFluxModel::load,
        py::arg("path"),
        py::arg("partial") = false
    )
    .def("loadDict", &QuantizedFluxModel::loadDict,
        py::arg("dict"),
        py::arg("partial") = false
    )
    .def("forward", &QuantizedFluxModel::forward,
        py::arg("hidden_states"),
        py::arg("encoder_hidden_states"),
        py::arg("temb"),
        py::arg("rotary_emb_img"),
        py::arg("rotary_emb_context"),
        py::arg("rotary_emb_single"),
        py::arg("controlnet_block_samples") = py::none(),
        py::arg("controlnet_single_block_samples") = py::none(),
        py::arg("skip_first_layer") = false
    )
    .def("forward_layer", &QuantizedFluxModel::forward_layer,
        py::arg("idx"),
        py::arg("hidden_states"),
        py::arg("encoder_hidden_states"),
        py::arg("temb"),
        py::arg("rotary_emb_img"),
        py::arg("rotary_emb_context"),
        py::arg("controlnet_block_samples") = py::none(),
        py::arg("controlnet_single_block_samples") = py::none()
    )
    .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
    .def("startDebug", &QuantizedFluxModel::startDebug)
    .def("stopDebug", &QuantizedFluxModel::stopDebug)
@@ -46,11 +65,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("deviceId")
    )
    .def("reset", &QuantizedSanaModel::reset)
    .def("load", &QuantizedSanaModel::load,
        py::arg("path"),
        py::arg("partial") = false
    )
    .def("loadDict", &QuantizedSanaModel::loadDict,
        py::arg("dict"),
        py::arg("partial") = false
    )
...
from typing import Any, Dict, Optional, Union

import logging
import os
@@ -5,6 +7,7 @@ import diffusers
import torch
from diffusers import FluxTransformer2DModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from huggingface_hub import utils
from packaging.version import Version
from safetensors.torch import load_file, save_file
@@ -62,6 +65,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        encoder_hidden_states: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        joint_attention_kwargs=None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        skip_first_layer=False,
    ):
        batch_size = hidden_states.shape[0]
@@ -76,6 +81,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        temb = temb.to(self.dtype).to(self.device)
        image_rotary_emb = image_rotary_emb.to(self.device)

        if controlnet_block_samples is not None:
            controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
        if controlnet_single_block_samples is not None:
            controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)

        assert image_rotary_emb.ndim == 6
        assert image_rotary_emb.shape[0] == 1
        assert image_rotary_emb.shape[1] == 1
@@ -89,7 +99,6 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
        rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
        rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
        hidden_states = self.m.forward(
            hidden_states,
            encoder_hidden_states,
@@ -97,6 +106,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
            rotary_emb_img,
            rotary_emb_txt,
            rotary_emb_single,
            controlnet_block_samples,
            controlnet_single_block_samples,
            skip_first_layer,
        )
@@ -115,6 +126,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        temb: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        joint_attention_kwargs=None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
    ):
        batch_size = hidden_states.shape[0]
        txt_tokens = encoder_hidden_states.shape[1]
@@ -128,6 +141,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        temb = temb.to(self.dtype).to(self.device)
        image_rotary_emb = image_rotary_emb.to(self.device)

        if controlnet_block_samples is not None:
            controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
        if controlnet_single_block_samples is not None:
            controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)

        assert image_rotary_emb.ndim == 6
        assert image_rotary_emb.shape[0] == 1
        assert image_rotary_emb.shape[1] == 1
@@ -141,7 +159,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
        rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
        hidden_states, encoder_hidden_states = self.m.forward_layer(
            idx, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt,
            controlnet_block_samples, controlnet_single_block_samples
        )
        hidden_states = hidden_states.to(original_dtype).to(original_device)
@@ -473,3 +492,100 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
        state_dict.update(updated_vectors)
        self.transformer_blocks[0].m.loadDict(state_dict, True)
        self.reset_x_embedder()
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        Copied from diffusers.models.transformers.transformer_flux.py

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate the denoising step.
            controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
                A list of tensors that, if specified, are added to the residuals of the transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        else:
            guidance = None

        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` as a 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch.Tensor."
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` as a 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch.Tensor."
            )
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

        nunchaku_block = self.transformer_blocks[0]
        encoder_hidden_states, hidden_states = nunchaku_block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb=temb,
            image_rotary_emb=image_rotary_emb,
            joint_attention_kwargs=joint_attention_kwargs,
            controlnet_block_samples=controlnet_block_samples,
            controlnet_single_block_samples=controlnet_single_block_samples,
        )

        # drop the text tokens and keep only the image tokens for the output projection
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1]:, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
@@ -40,7 +40,7 @@ Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device) :
    dim(dim),
    linear(dim, 3 * dim, true, dtype, device),
    norm(dim, 1e-6, false, dtype, device)
{
    registerChildren
        (linear, "linear")
@@ -59,12 +59,12 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
    debug("x", x);
    Tensor norm_x = norm.forward(x);
    debug("norm_x", norm_x);
    kernels::mul_add(norm_x, scale_msa, shift_msa);
    return Output{norm_x, gate_msa};
}
AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device) :
    dim(dim), pre_only(pre_only),
    linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
    norm(dim, 1e-6, false, dtype, device)
@@ -91,7 +91,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
        kernels::mul_add(norm_x, scale_msa, shift_msa);
        debug("norm_x_scaled", norm_x);
        return Output{norm_x};
    } else {
        auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(emb);
@@ -108,7 +108,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
}
Attention::Attention(int num_heads, int dim_head, Device device) :
    num_heads(num_heads), dim_head(dim_head), force_fp16(false)
{
    headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
@@ -151,7 +151,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
            gemm_batched_fp16(pool_q, pool_k, pool_s);
        }
    }
    blockmask = kernels::topk(pool_score, pool_tokens * (1 - sparsityRatio));
    if (cu_seqlens_cpu.valid()) {
@@ -227,9 +227,9 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
        false
    ).front();
    Tensor raw_attn_output = mha_fwd(q, k, v,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, -1, -1, false
    ).front();
@@ -261,7 +261,7 @@ void Attention::setForceFP16(Module *module, bool value) {
}
FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    dim(dim),
    dim_head(attention_head_dim / num_attention_heads),
    num_heads(num_attention_heads),
    mlp_hidden_dim(dim * mlp_ratio),
@@ -311,7 +311,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
        qkv_proj.forward(norm_hidden_states, qkv, {}, norm_q.weight, norm_k.weight, rotary_emb);
        debug("qkv", qkv);
        // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
        attn_output = attn.forward(qkv, {}, 0);
        attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
@@ -340,7 +340,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
    debug("raw_attn_output", attn_output);
    attn_output = forward_fc(out_proj, attn_output);
    debug("attn_output", attn_output);
@@ -350,7 +350,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
    hidden_states = kernels::add(attn_output, ff_output);
    debug("attn_ff_output", hidden_states);
    kernels::mul_add(hidden_states, gate, residual);
    nvtxRangePop();
@@ -358,7 +358,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
    return hidden_states;
}
JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    dim(dim),
    dim_head(attention_head_dim / num_attention_heads),
    num_heads(num_attention_heads),
@@ -416,7 +416,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
    int num_tokens_img = hidden_states.shape[1];
    int num_tokens_txt = encoder_hidden_states.shape[1];
    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);
@@ -439,7 +439,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
    nvtxRangePop();
    auto stream = getCurrentCUDAStream();
    int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
    Tensor raw_attn_output;
@@ -449,66 +449,66 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
        Tensor concat;
        Tensor pool;
        {
            nvtxRangePushA("qkv_proj");
            const bool blockSparse = sparsityRatio > 0;
            const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
            concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());
            pool = blockSparse
                ? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
                : Tensor{};
            for (int i = 0; i < batch_size; i++) {
                // img first
                Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
                Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
                Tensor pool_qkv = pool.valid()
                    ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE)
                    : Tensor{};
                Tensor pool_qkv_context = pool.valid()
                    ? concat.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
                    : Tensor{};
                // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
                // debug("qkv_raw", qkv);
                debug("rotary_emb", rotary_emb);
                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
                debug("qkv", qkv);
                // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
                // debug("qkv_context_raw", qkv_context);
                debug("rotary_emb_context", rotary_emb_context);
                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context, pool_qkv_context, norm_added_q.weight, norm_added_k.weight, rotary_emb_context);
                debug("qkv_context", qkv_context);
            }
            nvtxRangePop();
        }
        spdlog::debug("concat={}", concat.shape.str());
        debug("concat", concat);
        assert(concat.shape[2] == num_heads * dim_head * 3);
        nvtxRangePushA("Attention");
        raw_attn_output = attn.forward(concat, pool, sparsityRatio);
        nvtxRangePop();
        spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
        num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;
@@ -517,11 +517,11 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
        {
            nvtxRangePushA("qkv_proj");
            concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head}, Tensor::FP16, norm1_output.x.device());
            concat_k = Tensor::empty_like(concat_q);
            concat_v = Tensor::empty_like(concat_q);
            for (int i = 0; i < batch_size; i++) {
                // img first
                auto sliceImg = [&](Tensor x) {
@@ -530,12 +530,12 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
                auto sliceTxt = [&](Tensor x) {
                    return x.slice(0, i, i+1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
                };
                qkv_proj.forward(
                    norm1_output.x.slice(0, i, i + 1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb,
                    sliceImg(concat_q), sliceImg(concat_k), sliceImg(concat_v), num_tokens_img
                );
                qkv_proj_context.forward(
                    norm1_context_output.x.slice(0, i, i + 1), {}, {}, norm_added_q.weight, norm_added_k.weight, rotary_emb_context,
                    sliceTxt(concat_q), sliceTxt(concat_k), sliceTxt(concat_v), num_tokens_txt
@@ -545,7 +545,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
            debug("concat_q", concat_q);
            debug("concat_k", concat_k);
            debug("concat_v", concat_v);
            nvtxRangePop();
        }
@@ -718,7 +718,16 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
    }
}

Tensor FluxModel::forward(
    Tensor hidden_states,
    Tensor encoder_hidden_states,
    Tensor temb,
    Tensor rotary_emb_img,
    Tensor rotary_emb_context,
    Tensor rotary_emb_single,
    Tensor controlnet_block_samples,
    Tensor controlnet_single_block_samples,
    bool skip_first_layer) {
    const int batch_size = hidden_states.shape[0];
    const Tensor::ScalarType dtype = hidden_states.dtype();
    const Device device = hidden_states.device();
@@ -727,6 +736,8 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
    const int img_tokens = hidden_states.shape[1];
    const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();
    const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
    const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
    Tensor concat;
@@ -735,6 +746,14 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
            if (controlnet_block_samples.valid()) {
                int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
                int block_index = layer / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_block_samples;
                hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
            }
        } else {
            if (size_t(layer) == transformer_blocks.size()) {
                // txt first, same as diffusers
@@ -745,10 +764,21 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
                }
                hidden_states = concat;
                encoder_hidden_states = {};
            }
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
            if (controlnet_single_block_samples.valid()) {
                int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
                int block_index = (layer - transformer_blocks.size()) / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_single_block_samples;
                // the ControlNet residual is added to the image tokens only, which sit after the text tokens
                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
        }
    };
    auto load = [&](int layer) {
@@ -776,6 +806,50 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
    return hidden_states;
}
std::tuple<Tensor, Tensor> FluxModel::forward_layer(
    size_t layer,
    Tensor hidden_states,
    Tensor encoder_hidden_states,
    Tensor temb,
    Tensor rotary_emb_img,
    Tensor rotary_emb_context,
    Tensor controlnet_block_samples,
    Tensor controlnet_single_block_samples) {

    std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
        hidden_states,
        encoder_hidden_states,
        temb,
        rotary_emb_img,
        rotary_emb_context, 0.0f);

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];
    const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
    const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

    if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
        int block_index = layer / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_block_samples;
        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
    } else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
        int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
        int block_index = (layer - transformer_blocks.size()) / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_single_block_samples;
        auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
        slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
        hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
    }

    return { hidden_states, encoder_hidden_states };
}
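The `interval_control` arithmetic above spreads a short list of ControlNet residuals over a longer stack of blocks: each sample covers a contiguous run of `ceil(num_blocks / num_samples)` layers. A small Python sketch of the same mapping (19 joint blocks is the standard FLUX.1-dev depth; 4 samples is just an illustrative count):

# Sketch: which ControlNet sample each transformer block receives,
# mirroring the ceilDiv + integer division in FluxModel::forward above.
from math import ceil

def controlnet_sample_index(layer: int, num_blocks: int, num_samples: int) -> int:
    interval_control = ceil(num_blocks / num_samples)
    return layer // interval_control

print([controlnet_sample_index(i, 19, 4) for i in range(19)])
# -> [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3]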
void FluxModel::setAttentionImpl(AttentionImpl impl) {
    for (auto &&block : this->transformer_blocks) {
        block->attnImpl = impl;
...
@@ -61,7 +61,7 @@ private:
class Attention : public Module {
public:
    static constexpr int POOL_SIZE = 128;
    Attention(int num_heads, int dim_head, Device device);
    Tensor forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio);
@@ -138,13 +138,30 @@ private:
class FluxModel : public Module {
public:
    FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
    Tensor forward(
        Tensor hidden_states,
        Tensor encoder_hidden_states,
        Tensor temb,
        Tensor rotary_emb_img,
        Tensor rotary_emb_context,
        Tensor rotary_emb_single,
        Tensor controlnet_block_samples,
        Tensor controlnet_single_block_samples,
        bool skip_first_layer = false);
    std::tuple<Tensor, Tensor> forward_layer(
        size_t layer,
        Tensor hidden_states,
        Tensor encoder_hidden_states,
        Tensor temb,
        Tensor rotary_emb_img,
        Tensor rotary_emb_context,
        Tensor controlnet_block_samples,
        Tensor controlnet_single_block_samples);
    void setAttentionImpl(AttentionImpl impl);
public:
    const Tensor::ScalarType dtype;
    std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
    std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
...
@@ -13,7 +13,7 @@ public:
        this->device.type = this->tensor.is_cuda() ? Device::CUDA : Device::CPU;
        this->device.idx = this->tensor.get_device();
    }
    virtual bool isAsyncBuffer() override {
        // TODO: figure out how torch manages memory
        return this->device.type == Device::CUDA;
    }
...