Commit 235238bd authored by Hyunsung Lee, committed by Zhekai Zhang

Add controlnet

parent 63913f29
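# Example: FLUX.1-dev + ControlNet-Union-Pro (depth + canny conditioning) running on the
# Nunchaku SVDQuant INT4 transformer, with block caching enabled via apply_cache_on_pipe.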
import random
import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from nunchaku import NunchakuFluxTransformer2dModel
from diffusers.utils import load_image
import numpy as np
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union])  # wrapping in FluxMultiControlNetModel is recommended, even for a single ControlNet
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe = FluxControlNetPipeline.from_pretrained(
    base_model,
    transformer=transformer,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
)
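# Caching (assumption: first-block residual caching) reuses cached block outputs within a
# denoising step when the early-block residual changes by less than residual_diff_threshold;
# a larger threshold trades output fidelity for speed.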
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
pipe.to("cuda")
prompt = 'An anime-style girl with messy beach waves.'
control_image_depth = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg")
control_mode_depth = 2
control_image_canny = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg")
control_mode_canny = 0
width, height = control_image_depth.size
image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("nunchaku-controlnet-flux.1-dev.png")
import random
import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from nunchaku import NunchakuFluxTransformer2dModel
from diffusers.utils import load_image
import numpy as np
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union])  # wrapping in FluxMultiControlNetModel is recommended, even for a single ControlNet
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe = FluxControlNetPipeline.from_pretrained(
    base_model,
    transformer=transformer,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
prompt = 'An anime-style girl with messy beach waves.'
control_image_depth = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg")
control_mode_depth = 2
control_image_canny = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg")
control_mode_canny = 0
width, height = control_image_depth.size
image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("nunchaku-controlnet-flux.1-dev.png")
import random
import torch
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
from nunchaku import NunchakuFluxTransformer2dModel
from diffusers.utils import load_image
import numpy as np
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
base_model = 'black-forest-labs/FLUX.1-dev'
controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
controlnet = FluxMultiControlNetModel([controlnet_union])  # wrapping in FluxMultiControlNetModel is recommended, even for a single ControlNet
pipe = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = 'An anime-style girl with messy beach waves.'
control_image_depth = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/depth.jpg")
control_mode_depth = 2
control_image_canny = load_image("https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/canny.jpg")
control_mode_canny = 0
width, height = control_image_depth.size
image = pipe(
    prompt,
    control_image=[control_image_depth, control_image_canny],
    control_mode=[control_mode_depth, control_mode_canny],
    width=width,
    height=height,
    controlnet_conditioning_scale=[0.3, 0.1],
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=torch.manual_seed(SEED),
).images[0]
image.save("reference-controlnet-flux.1-dev.png")
......@@ -35,6 +35,8 @@ public:
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context,
torch::Tensor rotary_emb_single,
std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt,
bool skip_first_layer = false)
{
checkModel();
......@@ -56,6 +58,8 @@ public:
from_torch(rotary_emb_img),
from_torch(rotary_emb_context),
from_torch(rotary_emb_single),
controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{},
skip_first_layer
);
......@@ -71,7 +75,9 @@ public:
torch::Tensor encoder_hidden_states,
torch::Tensor temb,
torch::Tensor rotary_emb_img,
torch::Tensor rotary_emb_context)
torch::Tensor rotary_emb_context,
std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt)
{
CUDADeviceContext ctx(deviceId);
......@@ -83,17 +89,19 @@ public:
rotary_emb_img = rotary_emb_img.contiguous();
rotary_emb_context = rotary_emb_context.contiguous();
auto &&[result_img, result_txt] = net->transformer_blocks.at(idx)->forward(
auto &&[hidden_states_, encoder_hidden_states_] = net->forward_layer(
idx,
from_torch(hidden_states),
from_torch(encoder_hidden_states),
from_torch(temb),
from_torch(rotary_emb_img),
from_torch(rotary_emb_context),
0.0f
controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous()) : Tensor{},
controlnet_single_block_samples.has_value() ? from_torch(controlnet_single_block_samples.value().contiguous()) : Tensor{}
);
hidden_states = to_torch(result_img);
encoder_hidden_states = to_torch(result_txt);
hidden_states = to_torch(hidden_states_);
encoder_hidden_states = to_torch(encoder_hidden_states_);
Tensor::synchronizeDevice();
return { hidden_states, encoder_hidden_states };
......
......@@ -10,7 +10,7 @@ class QuantizedGEMM : public ModuleWrapper<GEMM_W4A4> {
public:
void init(int64_t in_features, int64_t out_features, bool bias, bool use_fp4, bool bf16, int8_t deviceId) {
spdlog::info("Initializing QuantizedGEMM");
size_t val = 0;
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
......@@ -27,7 +27,7 @@ public:
x = x.contiguous();
Tensor result = net->forward(from_torch(x));
torch::Tensor output = to_torch(result);
Tensor::synchronizeDevice();
......@@ -48,7 +48,7 @@ public:
const int M = x.shape[0];
const int K = x.shape[1] * 2;
assert(x.dtype() == Tensor::INT8);
// activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t (uint4)
......@@ -83,7 +83,7 @@ public:
}
}
}
ss << std::endl;
return ss.str();
}
......@@ -99,7 +99,7 @@ public:
from_torch(x),
fuse_glu
);
Tensor act = qout.act.copy(Device::cpu());
Tensor ascales = qout.ascales.copy(Device::cpu());
Tensor lora_act = qout.lora_act.copy(Device::cpu());
......
......@@ -18,16 +18,35 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("deviceId")
)
.def("reset", &QuantizedFluxModel::reset)
.def("load", &QuantizedFluxModel::load,
.def("load", &QuantizedFluxModel::load,
py::arg("path"),
py::arg("partial") = false
)
.def("loadDict", &QuantizedFluxModel::loadDict,
.def("loadDict", &QuantizedFluxModel::loadDict,
py::arg("dict"),
py::arg("partial") = false
)
.def("forward", &QuantizedFluxModel::forward)
.def("forward_layer", &QuantizedFluxModel::forward_layer)
.def("forward", &QuantizedFluxModel::forward,
py::arg("hidden_states"),
py::arg("encoder_hidden_states"),
py::arg("temb"),
py::arg("rotary_emb_img"),
py::arg("rotary_emb_context"),
py::arg("rotary_emb_single"),
py::arg("controlnet_block_samples") = py::none(),
py::arg("controlnet_single_block_samples") = py::none(),
py::arg("skip_first_layer") = false
)
.def("forward_layer", &QuantizedFluxModel::forward_layer,
py::arg("idx"),
py::arg("hidden_states"),
py::arg("encoder_hidden_states"),
py::arg("temb"),
py::arg("rotary_emb_img"),
py::arg("rotary_emb_context"),
py::arg("controlnet_block_samples") = py::none(),
py::arg("controlnet_single_block_samples") = py::none()
)
.def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
.def("startDebug", &QuantizedFluxModel::startDebug)
.def("stopDebug", &QuantizedFluxModel::stopDebug)
......@@ -46,11 +65,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("deviceId")
)
.def("reset", &QuantizedSanaModel::reset)
.def("load", &QuantizedSanaModel::load,
.def("load", &QuantizedSanaModel::load,
py::arg("path"),
py::arg("partial") = false
)
.def("loadDict", &QuantizedSanaModel::loadDict,
.def("loadDict", &QuantizedSanaModel::loadDict,
py::arg("dict"),
py::arg("partial") = false
)
......
from typing import Any, Dict, Optional, Union
import logging
import os
......@@ -5,6 +7,7 @@ import diffusers
import torch
from diffusers import FluxTransformer2DModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from huggingface_hub import utils
from packaging.version import Version
from safetensors.torch import load_file, save_file
......@@ -62,6 +65,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
encoder_hidden_states: torch.Tensor,
image_rotary_emb: torch.Tensor,
joint_attention_kwargs=None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
skip_first_layer=False,
):
batch_size = hidden_states.shape[0]
......@@ -76,6 +81,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
temb = temb.to(self.dtype).to(self.device)
image_rotary_emb = image_rotary_emb.to(self.device)
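# The quantized C++ forward expects the ControlNet residuals as one stacked tensor rather
# than a Python list, so stack them and move them onto the transformer's device.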
if controlnet_block_samples is not None:
controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
if controlnet_single_block_samples is not None:
controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
assert image_rotary_emb.ndim == 6
assert image_rotary_emb.shape[0] == 1
assert image_rotary_emb.shape[1] == 1
......@@ -89,7 +99,6 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
hidden_states = self.m.forward(
hidden_states,
encoder_hidden_states,
......@@ -97,6 +106,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_img,
rotary_emb_txt,
rotary_emb_single,
controlnet_block_samples,
controlnet_single_block_samples,
skip_first_layer,
)
......@@ -115,6 +126,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
temb: torch.Tensor,
image_rotary_emb: torch.Tensor,
joint_attention_kwargs=None,
controlnet_block_samples=None,
controlnet_single_block_samples=None
):
batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
......@@ -128,6 +141,11 @@ class NunchakuFluxTransformerBlocks(nn.Module):
temb = temb.to(self.dtype).to(self.device)
image_rotary_emb = image_rotary_emb.to(self.device)
if controlnet_block_samples is not None:
controlnet_block_samples = torch.stack(controlnet_block_samples).to(self.device)
if controlnet_single_block_samples is not None:
controlnet_single_block_samples = torch.stack(controlnet_single_block_samples).to(self.device)
assert image_rotary_emb.ndim == 6
assert image_rotary_emb.shape[0] == 1
assert image_rotary_emb.shape[1] == 1
......@@ -141,7 +159,8 @@ class NunchakuFluxTransformerBlocks(nn.Module):
rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
hidden_states, encoder_hidden_states = self.m.forward_layer(
idx, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt
idx, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_txt,
controlnet_block_samples, controlnet_single_block_samples
)
hidden_states = hidden_states.to(original_dtype).to(original_device)
......@@ -473,3 +492,100 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
state_dict.update(updated_vectors)
self.transformer_blocks[0].m.loadDict(state_dict, True)
self.reset_x_embedder()
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
Adapted from diffusers.models.transformers.transformer_flux.FluxTransformer2DModel.forward.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
Residuals from the ControlNet double-stream blocks, added to the outputs of the corresponding transformer blocks.
controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
Residuals from the ControlNet single-stream blocks, added to the image tokens after the corresponding single blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
nunchaku_block = self.transformer_blocks[0]
encoder_hidden_states, hidden_states = nunchaku_block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples
)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
......@@ -40,7 +40,7 @@ Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device) :
dim(dim),
linear(dim, 3 * dim, true, dtype, device),
norm(dim, 1e-6, false, dtype, device)
{
registerChildren
(linear, "linear")
......@@ -59,12 +59,12 @@ AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor
debug("x", x);
Tensor norm_x = norm.forward(x);
debug("norm_x", norm_x);
kernels::mul_add(norm_x, scale_msa, shift_msa);
return Output{norm_x, gate_msa};
}
AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device) :
dim(dim), pre_only(pre_only),
linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
norm(dim, 1e-6, false, dtype, device)
......@@ -91,7 +91,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
kernels::mul_add(norm_x, scale_msa, shift_msa);
debug("norm_x_scaled", norm_x);
return Output{norm_x};
} else {
auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(emb);
......@@ -108,7 +108,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
}
Attention::Attention(int num_heads, int dim_head, Device device) :
num_heads(num_heads), dim_head(dim_head), force_fp16(false)
{
headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
......@@ -151,7 +151,7 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
gemm_batched_fp16(pool_q, pool_k, pool_s);
}
}
blockmask = kernels::topk(pool_score, pool_tokens * (1 - sparsityRatio));
if (cu_seqlens_cpu.valid()) {
......@@ -227,9 +227,9 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
false
).front();
Tensor raw_attn_output = mha_fwd(q, k, v,
0.0f,
pow(q.shape[-1], (-0.5)),
false, -1, -1, false
).front();
......@@ -261,7 +261,7 @@ void Attention::setForceFP16(Module *module, bool value) {
}
FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device) :
dim(dim),
dim_head(attention_head_dim / num_attention_heads),
num_heads(num_attention_heads),
mlp_hidden_dim(dim * mlp_ratio),
......@@ -311,7 +311,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
qkv_proj.forward(norm_hidden_states, qkv, {}, norm_q.weight, norm_k.weight, rotary_emb);
debug("qkv", qkv);
// Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);
attn_output = attn.forward(qkv, {}, 0);
attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
} else if (attnImpl == AttentionImpl::NunchakuFP16) {
......@@ -340,7 +340,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
debug("raw_attn_output", attn_output);
attn_output = forward_fc(out_proj, attn_output);
debug("attn_output", attn_output);
......@@ -350,7 +350,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
hidden_states = kernels::add(attn_output, ff_output);
debug("attn_ff_output", hidden_states);
kernels::mul_add(hidden_states, gate, residual);
nvtxRangePop();
......@@ -358,7 +358,7 @@ Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Te
return hidden_states;
}
JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) :
dim(dim),
dim_head(attention_head_dim / num_attention_heads),
num_heads(num_attention_heads),
......@@ -416,7 +416,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
int num_tokens_img = hidden_states.shape[1];
int num_tokens_txt = encoder_hidden_states.shape[1];
assert(hidden_states.shape[2] == dim);
assert(encoder_hidden_states.shape[2] == dim);
......@@ -439,7 +439,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
nvtxRangePop();
auto stream = getCurrentCUDAStream();
int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
Tensor raw_attn_output;
......@@ -449,66 +449,66 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
Tensor concat;
Tensor pool;
{
nvtxRangePushA("qkv_proj");
const bool blockSparse = sparsityRatio > 0;
const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());
pool = blockSparse
? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
: Tensor{};
for (int i = 0; i < batch_size; i++) {
// img first
Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
Tensor pool_qkv = pool.valid()
? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE)
: Tensor{};
Tensor pool_qkv_context = pool.valid()
? concat.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
: Tensor{};
// qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
// debug("qkv_raw", qkv);
debug("rotary_emb", rotary_emb);
qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
debug("qkv", qkv);
// qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
// debug("qkv_context_raw", qkv_context);
debug("rotary_emb_context", rotary_emb_context);
qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context, pool_qkv_context, norm_added_q.weight, norm_added_k.weight, rotary_emb_context);
debug("qkv_context", qkv_context);
}
nvtxRangePop();
}
spdlog::debug("concat={}", concat.shape.str());
debug("concat", concat);
assert(concat.shape[2] == num_heads * dim_head * 3);
nvtxRangePushA("Attention");
raw_attn_output = attn.forward(concat, pool, sparsityRatio);
nvtxRangePop();
spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});
} else if (attnImpl == AttentionImpl::NunchakuFP16) {
num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;
......@@ -517,11 +517,11 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
{
nvtxRangePushA("qkv_proj");
concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head}, Tensor::FP16, norm1_output.x.device());
concat_k = Tensor::empty_like(concat_q);
concat_v = Tensor::empty_like(concat_q);
for (int i = 0; i < batch_size; i++) {
// img first
auto sliceImg = [&](Tensor x) {
......@@ -530,12 +530,12 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
auto sliceTxt = [&](Tensor x) {
return x.slice(0, i, i+1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
};
qkv_proj.forward(
norm1_output.x.slice(0, i, i + 1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb,
sliceImg(concat_q), sliceImg(concat_k), sliceImg(concat_v), num_tokens_img
);
qkv_proj_context.forward(
norm1_context_output.x.slice(0, i, i + 1), {}, {}, norm_added_q.weight, norm_added_k.weight, rotary_emb_context,
sliceTxt(concat_q), sliceTxt(concat_k), sliceTxt(concat_v), num_tokens_txt
......@@ -545,7 +545,7 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
debug("concat_q", concat_q);
debug("concat_k", concat_k);
debug("concat_v", concat_v);
nvtxRangePop();
}
......@@ -718,7 +718,16 @@ FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Devic
}
}
Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single, bool skip_first_layer) {
Tensor FluxModel::forward(
Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor rotary_emb_single,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples,
bool skip_first_layer) {
const int batch_size = hidden_states.shape[0];
const Tensor::ScalarType dtype = hidden_states.dtype();
const Device device = hidden_states.device();
......@@ -727,6 +736,8 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
const int img_tokens = hidden_states.shape[1];
const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();
const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
Tensor concat;
......@@ -735,6 +746,14 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
if (size_t(layer) < transformer_blocks.size()) {
auto &block = transformer_blocks.at(layer);
std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
if (controlnet_block_samples.valid()) {
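// Map this double-stream block onto its ControlNet residual: each residual covers
// interval_control consecutive blocks, mirroring diffusers' FluxTransformer2DModel indexing.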
int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
int block_index = layer / interval_control;
// Xlabs ControlNet
// block_index = layer % num_controlnet_block_samples;
hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
}
} else {
if (size_t(layer) == transformer_blocks.size()) {
// txt first, same as diffusers
......@@ -745,10 +764,21 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
}
hidden_states = concat;
encoder_hidden_states = {};
}
auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
if (controlnet_single_block_samples.valid()) {
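// Single-stream residuals apply only to the image tokens, so add them to the
// [txt_tokens, txt_tokens + img_tokens) slice of the concatenated hidden states.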
int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
int block_index = (layer - transformer_blocks.size()) / interval_control;
// Xlabs ControlNet
// block_index = layer % num_controlnet_single_block_samples
auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
}
}
};
auto load = [&](int layer) {
......@@ -776,6 +806,50 @@ Tensor FluxModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Te
return hidden_states;
}
std::tuple<Tensor, Tensor> FluxModel::forward_layer(
size_t layer,
Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples) {
std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_context, 0.0f);
const int txt_tokens = encoder_hidden_states.shape[1];
const int img_tokens = hidden_states.shape[1];
const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
int block_index = layer / interval_control;
// Xlabs ControlNet
// block_index = layer % num_controlnet_block_samples;
hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
} else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
int block_index = (layer - transformer_blocks.size()) / interval_control;
// Xlabs ControlNet
// block_index = layer % num_controlnet_single_block_samples
auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
}
return { hidden_states, encoder_hidden_states };
}
void FluxModel::setAttentionImpl(AttentionImpl impl) {
for (auto &&block : this->transformer_blocks) {
block->attnImpl = impl;
......
......@@ -61,7 +61,7 @@ private:
class Attention : public Module {
public:
static constexpr int POOL_SIZE = 128;
Attention(int num_heads, int dim_head, Device device);
Tensor forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio);
......@@ -138,13 +138,30 @@ private:
class FluxModel : public Module {
public:
FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb_img, Tensor rotary_emb_context, Tensor rotary_emb_single, bool skip_first_layer = false);
Tensor forward(
Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor rotary_emb_single,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples,
bool skip_first_layer = false);
std::tuple<Tensor, Tensor> forward_layer(
size_t layer,
Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples);
void setAttentionImpl(AttentionImpl impl);
public:
const Tensor::ScalarType dtype;
std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;
......
......@@ -13,7 +13,7 @@ public:
this->device.type = this->tensor.is_cuda() ? Device::CUDA : Device::CPU;
this->device.idx = this->tensor.get_device();
}
virtual bool isAsyncBuffer() override {
// TODO: figure out how torch manages memory
return this->device.type == Device::CUDA;
}
......