Unverified Commit b87f4d3e authored by shencuifeng's avatar shencuifeng Committed by GitHub
Browse files

fix: support for transformer block indices > 19 in forward_layer (FluxModel.cpp) (#308)

* fix forward_layer in FluxModel.cpp

* Update FluxModel.cpp
parent f82dc8fb
......@@ -883,12 +883,22 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples) {
std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_context, 0.0f);
if (layer < transformer_blocks.size()){
std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_context, 0.0f);
}
else {
std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer - transformer_blocks.size())->forward(
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_context, 0.0f);
}
const int txt_tokens = encoder_hidden_states.shape[1];
const int img_tokens = hidden_states.shape[1];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment