Commit bf3669dd authored by sxtyzhangzk's avatar sxtyzhangzk Committed by Zhekai Zhang
Browse files

[major] fix runtime error when controlnet is not present

parent 235238bd
......@@ -736,8 +736,6 @@ Tensor FluxModel::forward(
const int img_tokens = hidden_states.shape[1];
const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();
const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
Tensor concat;
......@@ -747,6 +745,8 @@ Tensor FluxModel::forward(
auto &block = transformer_blocks.at(layer);
std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
if (controlnet_block_samples.valid()) {
const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
int block_index = layer / interval_control;
// Xlabs ControlNet
......@@ -770,6 +770,8 @@ Tensor FluxModel::forward(
auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
if (controlnet_single_block_samples.valid()) {
const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
int block_index = (layer - transformer_blocks.size()) / interval_control;
// Xlabs ControlNet
......@@ -826,10 +828,9 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(
const int txt_tokens = encoder_hidden_states.shape[1];
const int img_tokens = hidden_states.shape[1];
const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
int block_index = layer / interval_control;
// Xlabs ControlNet
......@@ -837,6 +838,8 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(
hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
} else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
int block_index = (layer - transformer_blocks.size()) / interval_control;
// Xlabs ControlNet
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment