#include "FluxModel.h"
#include "kernels/misc_kernels.h"
#include "kernels/gemm_batched.h"
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "activation.h"

#include <nvtx3/nvToolsExt.h>

#include <iostream>

using spdlog::fmt_lib::format;
using namespace nunchaku;



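// Feed-forward helper: fc1 runs with a fused GELU + re-quantization epilogue, so its
// quantized activation can be fed straight into fc2.forward_quant without an extra
// dequantize/quantize round trip.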
Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
    Tensor ff_output = fc2.forward_quant(
        std::get<GEMM_W4A4::QuantizedActivation>(fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2))
    );
    return ff_output;
}

// Tensor forward_mlp(GEMM_W8A8 &fc1, GEMM_W8A8 &fc2, Tensor norm_hidden_states) {
//     Tensor ff_output = fc2.forward(fc1.forward(norm_hidden_states), GEMM_W8A8::FuseOptions::GELU);
//     return ff_output;
// }


Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
    return fc.forward(x);
    // return std::get<Tensor>(fc.forward(x));
}

// Tensor forward_fc(GEMM_W8A8 &fc, Tensor x) {
//     return fc.forward(x);
// }


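// adaLN-Zero (single-stream variant): SiLU + Linear on the conditioning embedding
// produces shift/scale/gate tensors that modulate the LayerNorm output.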
AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device) :
    dim(dim),
    linear(dim, 3 * dim, true, dtype, device),
    norm(dim, 1e-6, false, dtype, device)
{
    registerChildren
        (linear, "linear")
        (norm, "norm")
    ;
}

AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor emb) {
    debug("emb_input", emb);
    emb = linear.forward(Silu::forward(emb));
    debug("emb_linear", emb);
    auto &&[shift_msa, scale_msa, gate_msa] = kernels::split_mod<3>(emb);
    debug("scale_msa", scale_msa);
    debug("shift_msa", shift_msa);

    debug("x", x);
    Tensor norm_x = norm.forward(x);
    debug("norm_x", norm_x);

    kernels::mul_add(norm_x, scale_msa, shift_msa);
    return Output{norm_x, gate_msa};
}

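// adaLN-Zero (dual-stream variant): emits 2 modulation tensors when pre_only,
// otherwise 6 (shift/scale/gate for both the attention and the MLP branch).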
AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device) :
    dim(dim), pre_only(pre_only),
    linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
    norm(dim, 1e-6, false, dtype, device)
{
    registerChildren
        (linear, "linear")
        (norm, "norm")
    ;
}

AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
    debug("x", x);

    debug("emb_input", emb);
    emb = linear.forward(Silu::forward(emb));
    debug("emb_linear", emb);

    if (pre_only) {
        auto &&[shift_msa, scale_msa] = kernels::split_mod<2>(emb);
        debug("shift_msa", shift_msa);

        Tensor norm_x = norm.forward(x);
        debug("norm_x", norm_x);

        kernels::mul_add(norm_x, scale_msa, shift_msa);
        debug("norm_x_scaled", norm_x);

        return Output{norm_x};
    } else {
        auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(emb);
        debug("shift_msa", shift_msa);

        Tensor norm_x = norm.forward(x);
        debug("norm_x", norm_x);

        kernels::mul_add(norm_x, scale_msa, shift_msa);
        debug("norm_x_scaled", norm_x);

        return Output{norm_x, gate_msa, shift_mlp, scale_mlp, gate_mlp};
    }
}


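// headmask_type holds per-head mask ids (1..num_heads) consumed by the block-sparse
// attention kernel (mha_fwd_block).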
Attention::Attention(int num_heads, int dim_head, Device device) :
    num_heads(num_heads), dim_head(dim_head), force_fp16(false)
{
    headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
    for (int i = 0; i < num_heads; i++) {
        headmask_type.data_ptr<int32_t>()[i] = i + 1;
    }
    headmask_type = headmask_type.copy(device);
}

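// Dense attention path: split the packed [batch, tokens, 3*heads*dim_head] QKV tensor
// into Q/K/V views and run FlashAttention (mha_fwd); returns [batch*tokens, heads, dim_head].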
Tensor Attention::forward(Tensor qkv) {
    assert(qkv.ndims() == 3);

    const Device device = qkv.device();
    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    Tensor reshaped = qkv.view({batch_size, num_tokens, num_heads * 3, dim_head});
    Tensor q = reshaped.slice(2, 0, num_heads);
    Tensor k = reshaped.slice(2, num_heads, num_heads * 2);
    Tensor v = reshaped.slice(2, num_heads * 2, num_heads * 3);

    Tensor raw_attn_output = mha_fwd(q, k, v,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, -1, -1, false
    ).front();

    assert(raw_attn_output.shape[0] == batch_size);
    assert(raw_attn_output.shape[1] == num_tokens);
    assert(raw_attn_output.shape[2] == num_heads);
    assert(raw_attn_output.shape[3] == dim_head);
    
    return raw_attn_output.view({batch_size * num_tokens, num_heads, dim_head});
}

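// Block-sparse attention path: pooled QKV (one entry per POOL_SIZE-token block) yields
// block-level scores via batched GEMMs; topk keeps the top (1 - sparsityRatio) fraction
// of blocks and the resulting block mask drives mha_fwd_block.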
Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
    const bool cast_fp16 = this->force_fp16 && qkv.scalar_type() != Tensor::FP16;

    assert(qkv.ndims() == 3);

    const Device device = qkv.device();
    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    constexpr int POOL_SIZE = 128;
    const int pool_tokens = ceilDiv(num_tokens, POOL_SIZE);

    Tensor blockmask;

    if (pool_qkv.valid()) {
        assert(pool_qkv.shape[0] == batch_size);
        assert(pool_qkv.shape[1] == pool_tokens);
        assert(pool_qkv.shape[2] == num_heads * dim_head * 3);
    }

    Tensor pool_score = Tensor::allocate({batch_size, num_heads, pool_tokens, pool_tokens}, Tensor::FP32, device);

    if (pool_qkv.valid() && sparsityRatio > 0) {
        pool_qkv = pool_qkv.view({batch_size, pool_tokens, 3, num_heads, dim_head});
        pool_qkv = pool_qkv.transpose(1, 2).transpose(2, 3);    // [batch_size, 3, num_heads, poolTokens, dim_head]
        for (int i = 0; i < batch_size; i++) {
            Tensor pool_q = pool_qkv.slice(0, i, i+1).slice(1, 0, 1);
            Tensor pool_k = pool_qkv.slice(0, i, i+1).slice(1, 1, 2);
            Tensor pool_s = pool_score.slice(0, i, i+1);
            gemm_batched_fp16(pool_q, pool_k, pool_s);
        }
    }

    blockmask = kernels::topk(pool_score, pool_tokens * (1 - sparsityRatio));

    if (cu_seqlens_cpu.valid()) {
        if (cu_seqlens_cpu.shape[0] != batch_size + 1) {
            cu_seqlens_cpu = Tensor{};
        } else {
            for (int i = 0; i <= batch_size; i++) {
                if (cu_seqlens_cpu.data_ptr<int32_t>()[i] != num_tokens * i) {
                    cu_seqlens_cpu = Tensor{};
                    break;
                }
            }
        }
    }
    if (!cu_seqlens_cpu.valid()) {
        cu_seqlens_cpu = Tensor::allocate({batch_size + 1}, Tensor::INT32, Device::cpu());
        cu_seqlens_cpu.data_ptr<int32_t>()[0] = 0;
        for (int i = 1; i <= batch_size; i++) {
            cu_seqlens_cpu.data_ptr<int32_t>()[i] = cu_seqlens_cpu.data_ptr<int32_t>()[i - 1] + num_tokens;
        }
    }

    if (cast_fp16) {
        Tensor tmp = Tensor::empty(qkv.shape.dataExtent, Tensor::FP16, qkv.device());
        kernels::cast(qkv, tmp);
        qkv = tmp;
    }

    debug("qkv", qkv);

    Tensor cu_seqlens = cu_seqlens_cpu.copy(device);

    Tensor reshaped = qkv.view({batch_size * num_tokens, num_heads * 3, dim_head});
    Tensor q = reshaped.slice(1, 0, num_heads);
    Tensor k = reshaped.slice(1, num_heads, num_heads * 2);
    Tensor v = reshaped.slice(1, num_heads * 2, num_heads * 3);

    spdlog::debug("q,k,v={}", q.shape.str());

    Tensor raw_attn_output = mha_fwd_block(
        q, k, v,
        cu_seqlens, cu_seqlens,
        POOL_SIZE, POOL_SIZE,
        headmask_type,
        {},
        blockmask,
        num_tokens,
        num_tokens,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, false, false, -1, -1
    ).front();

    debug("raw_attn_output", raw_attn_output);

    if (cast_fp16) {
        Tensor tmp = Tensor::empty(raw_attn_output.shape.dataExtent, Tensor::BF16, raw_attn_output.device());
        kernels::cast(raw_attn_output, tmp);
        raw_attn_output = tmp;
    }

    /**
    Tensor raw_attn_output = mha_varlen_fwd(q, k, v,
        cu_seqlens,
        cu_seqlens,
        concat.shape[1],
        concat.shape[1],
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false,
        true,
        -1, -1,
        false
    ).front();

    Tensor raw_attn_output = mha_fwd(q, k, v,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, -1, -1, false
    ).front();

    Tensor raw_attn_output = mha_varlen_fwd(
        q, k, v,
        cu_seqlens, cu_seqlens,
        num_tokens_img + num_tokens_txt, num_tokens_img + num_tokens_txt,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, false, -1, -1, false
    ).front();
    **/

    assert(raw_attn_output.shape[0] == batch_size * num_tokens);
    assert(raw_attn_output.shape[1] == num_heads);
    assert(raw_attn_output.shape[2] == dim_head);

    return raw_attn_output;
}

void Attention::setForceFP16(Module *module, bool value) {
    spdlog::info("{} force fp16 attention", value ? "Enable" : "Disable");

    module->traverse([&](Module *m) {
        if (Attention *attn = dynamic_cast<Attention *>(m)) {
            attn->force_fp16 = value;
        }
    });
}

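// Single-stream FLUX block: adaLN-Zero modulation, fused QKV projection (query/key
// normalization and rotary embedding folded into the projection), attention and output
// projection, plus a parallel MLP on the same normalized input; the gated sum is added
// back onto the residual.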
FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    dim(dim),
    dim_head(attention_head_dim / num_attention_heads),
    num_heads(num_attention_heads),
    mlp_hidden_dim(dim * mlp_ratio),
    norm(dim, dtype, device),
    mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
    mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device),
    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
    norm_q(dim_head, 1e-6, false, dtype, device),
    norm_k(dim_head, 1e-6, false, dtype, device),
    attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
    out_proj(dim, dim, true, use_fp4, dtype, device)
{
    registerChildren
        (norm, "norm")
        (mlp_fc1, "mlp_fc1")
        (mlp_fc2, "mlp_fc2")
        (qkv_proj, "qkv_proj")
        (norm_q, "norm_q")
        (norm_k, "norm_k")
        (attn, "attn")
        (out_proj, "out_proj")
    ;
}

Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {

    nvtxRangePushA("FluxSingleTransformerBlock");

    const int batch_size = hidden_states.shape[0];
    const int num_tokens = hidden_states.shape[1];

    auto &&[norm_hidden_states, gate] = this->norm.forward(hidden_states, temb);
    debug("norm_hidden_states", norm_hidden_states);
    debug("gate", gate);

    Tensor residual = hidden_states;

    Tensor attn_output;

    debug("rotary_emb", rotary_emb);
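    // Two attention backends: FlashAttention-2 consumes a packed QKV tensor, while the
    // Nunchaku FP16 kernel takes per-head Q/K/V buffers padded to a multiple of 256 tokens.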

    if (attnImpl == AttentionImpl::FlashAttention2) {
        Tensor qkv = Tensor::allocate({batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
        // qkv_proj.forward(norm_hidden_states, qkv, {});
        // debug("qkv_raw", qkv);

        qkv_proj.forward(norm_hidden_states, qkv, {}, norm_q.weight, norm_k.weight, rotary_emb);
        debug("qkv", qkv);
        // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);

        // attn_output = attn.forward(qkv, {}, 0);
        attn_output = attn.forward(qkv);
        attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        assert(batch_size == 1);

        const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;

        Tensor q = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor k = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor v = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());

        qkv_proj.forward(norm_hidden_states, {}, {}, norm_q.weight, norm_k.weight, rotary_emb, q, k, v, num_tokens);

        debug("packed_q", q);
        debug("packed_k", k);
        debug("packed_v", v);

        Tensor o = Tensor::allocate({batch_size, num_tokens_pad, num_heads * dim_head}, norm_hidden_states.scalar_type(), norm_hidden_states.device());

        kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));

        attn_output = o.slice(1, 0, num_tokens);
    } else {
        assert(false);
    }

    debug("raw_attn_output", attn_output);

    attn_output = forward_fc(out_proj, attn_output);
    debug("attn_output", attn_output);

    Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
    debug("ff_output", ff_output);

    hidden_states = kernels::add(attn_output, ff_output);
    debug("attn_ff_output", hidden_states);

    kernels::mul_add(hidden_states, gate, residual);

    nvtxRangePop();

    return hidden_states;
}

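// Dual-stream (joint) FLUX block: image and text tokens keep separate projections,
// norms and MLPs, but attend jointly through a single concatenated attention pass.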
JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    dim(dim),
    dim_head(attention_head_dim / num_attention_heads),
    num_heads(num_attention_heads),
    context_pre_only(context_pre_only),
    norm1(dim, false, dtype, device),
    norm1_context(dim, context_pre_only, dtype, device),
    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
    qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device),
    norm_q(dim_head, 1e-6, false, dtype, device),
    norm_k(dim_head, 1e-6, false, dtype, device),
    norm_added_q(dim_head, 1e-6, false, dtype, device),
    norm_added_k(dim_head, 1e-6, false, dtype, device),
    attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
    out_proj(dim, dim, true, use_fp4, dtype, device),
    out_proj_context(dim, dim, true, use_fp4, dtype, device),
    norm2(dim, 1e-6, false, dtype, device),
    norm2_context(dim, 1e-6, false, dtype, device),
    mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device),
    mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
    mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
    mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device)
{
    registerChildren
        (norm1, "norm1")
        (norm1_context, "norm1_context")
        (qkv_proj, "qkv_proj")
        (qkv_proj_context, "qkv_proj_context")
        (norm_q, "norm_q")
        (norm_k, "norm_k")
        (norm_added_q, "norm_added_q")
        (norm_added_k, "norm_added_k")
        (attn, "attn")
        (out_proj, "out_proj")
        (out_proj_context, "out_proj_context")
        (norm2, "norm2")
        (norm2_context, "norm2_context")
        (mlp_fc1, "mlp_fc1")
        (mlp_fc2, "mlp_fc2")
        (mlp_context_fc1, "mlp_context_fc1")
        (mlp_context_fc2, "mlp_context_fc2")
    ;
}


// hidden_states: [Batch, Width * Height, dim]
// encoder_hidden_states: [Batch, Token, dim]
std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio) {
    int batch_size = hidden_states.shape[0];
    assert(encoder_hidden_states.shape[0] == batch_size);

    nvtxRangePushA("JointTransformerBlock");

    nvtxRangePushA("AdaNorm");


    int num_tokens_img = hidden_states.shape[1];
    int num_tokens_txt = encoder_hidden_states.shape[1];

    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);

    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}", hidden_states.shape.str(), encoder_hidden_states.shape.str(), temb.shape.str());
    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);

    auto norm1_output = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

#if 0
    norm1_output.x = hidden_states;
    norm1_context_output.x = encoder_hidden_states;
#endif

    debug("norm_hidden_states", norm1_output.x);
    debug("norm_encoder_hidden_states", norm1_context_output.x);

    constexpr int POOL_SIZE = Attention::POOL_SIZE;

    nvtxRangePop();

    auto stream = getCurrentCUDAStream();

    int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
    Tensor raw_attn_output;

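    // FlashAttention-2 works on the concatenated img+txt sequence as-is; the Nunchaku FP16
    // kernel needs each stream padded to a multiple of 256 tokens, so the padded lengths
    // are tracked for slicing the attention output back apart afterwards.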
    if (attnImpl == AttentionImpl::FlashAttention2) {
        num_tokens_img_pad = num_tokens_img;
        num_tokens_txt_pad = num_tokens_txt;

        Tensor concat;
        Tensor pool;

        {
            nvtxRangePushA("qkv_proj");

            const bool blockSparse = sparsityRatio > 0;

            const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
            concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());

            pool = blockSparse
                ? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
                : Tensor{};

            for (int i = 0; i < batch_size; i++) {
                // img first
                Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
                Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);

                Tensor pool_qkv = pool.valid()
                    ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE)
                    : Tensor{};
                Tensor pool_qkv_context = pool.valid()
                    ? pool.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
                    : Tensor{};

                // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
                // debug("qkv_raw", qkv);

                debug("rotary_emb", rotary_emb);

                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
                debug("qkv", qkv);

                // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
                // debug("qkv_context_raw", qkv_context);

                debug("rotary_emb_context", rotary_emb_context);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context, pool_qkv_context, norm_added_q.weight, norm_added_k.weight, rotary_emb_context);
                debug("qkv_context", qkv_context);
            }

            nvtxRangePop();
        }

        spdlog::debug("concat={}", concat.shape.str());
        debug("concat", concat);

        assert(concat.shape[2] == num_heads * dim_head * 3);

        nvtxRangePushA("Attention");

        if (pool.valid()) {
            raw_attn_output = attn.forward(concat, pool, sparsityRatio);
        } else {
            raw_attn_output = attn.forward(concat);
        }

        nvtxRangePop();

        spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());

        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});

    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
        num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;

        Tensor concat_q, concat_k, concat_v;

        {
            nvtxRangePushA("qkv_proj");

            concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head}, Tensor::FP16, norm1_output.x.device());
            concat_k = Tensor::empty_like(concat_q);
            concat_v = Tensor::empty_like(concat_q);

            for (int i = 0; i < batch_size; i++) {
                // img first
                auto sliceImg = [&](Tensor x) {
                    return x.slice(0, i, i+1).slice(2, 0, num_tokens_img_pad);
                };
                auto sliceTxt = [&](Tensor x) {
                    return x.slice(0, i, i+1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
                };

                qkv_proj.forward(
                    norm1_output.x.slice(0, i, i + 1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb,
                    sliceImg(concat_q), sliceImg(concat_k), sliceImg(concat_v), num_tokens_img
                );

                qkv_proj_context.forward(
                    norm1_context_output.x.slice(0, i, i + 1), {}, {}, norm_added_q.weight, norm_added_k.weight, rotary_emb_context,
                    sliceTxt(concat_q), sliceTxt(concat_k), sliceTxt(concat_v), num_tokens_txt
                );
            }

            debug("concat_q", concat_q);
            debug("concat_k", concat_k);
            debug("concat_v", concat_v);

            nvtxRangePop();
        }

        raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head}, norm1_output.x.scalar_type(), norm1_output.x.device());

        nvtxRangePushA("Attention");

        kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));

        nvtxRangePop();

        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
    } else {
        assert(false);
    }

    debug("raw_attn_output", raw_attn_output);

    {
        nvtxRangePushA("o_proj");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;

        // raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]
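        // Image half: take the first num_tokens_img tokens of the joint attention output,
        // project them, apply the adaLN gate against the residual, then run the image MLP.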

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split = raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(
                raw_attn_output_split.data_ptr(),
                num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                raw_attn_output.data_ptr(),
                (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
                num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                batch_size,
                cudaMemcpyDeviceToDevice,
                stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("img.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output = forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
        debug("img.attn_output", attn_output);

#if 1
        kernels::mul_add(attn_output, gate_msa, hidden_states);
        hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", hidden_states.shape.str());

        Tensor norm_hidden_states = norm2.forward(hidden_states);
        debug("scale_mlp", scale_mlp);
        debug("shift_mlp", shift_mlp);
        kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        Tensor norm_hidden_states = hidden_states;
#endif

        // Tensor ff_output = mlp_fc2.forward(GELU::forward(mlp_fc1.forward(norm_hidden_states)));
        debug("img.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
        debug("img.ff_output", ff_output);

        debug("gate_mlp", gate_mlp);
        kernels::mul_add(ff_output, gate_mlp, hidden_states);
        hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", hidden_states.shape.str());
    }

    if (context_pre_only) {
        return { hidden_states, encoder_hidden_states };
    }

    {
        nvtxRangePushA("o_proj_context");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_context_output;
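        // Text (context) half: the text tokens start at offset num_tokens_img_pad in the
        // joint attention output; project, gate against the encoder residual, then run the
        // context MLP.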

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt).reshape({batch_size, num_tokens_txt, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(
                raw_attn_output_split.data_ptr(),
                num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
                num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                batch_size,
                cudaMemcpyDeviceToDevice,
                stream));
        }


        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("context.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output = forward_fc(out_proj_context, raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
        debug("context.attn_output", attn_output);

#if 1
        kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
        encoder_hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", encoder_hidden_states.shape.str());

        Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
        debug("c_scale_mlp", scale_mlp);
        debug("c_shift_mlp", shift_mlp);
        kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        auto norm_hidden_states = encoder_hidden_states;
#endif

        // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
        // Tensor ff_output = mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
        debug("context.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
        debug("context.ff_output", ff_output);

        debug("c_gate_mlp", gate_mlp);
        kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
        encoder_hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", encoder_hidden_states.shape.str());
    }

    nvtxRangePop();

    return { hidden_states, encoder_hidden_states };
}

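// FLUX.1 backbone: 19 dual-stream (joint) blocks followed by 38 single-stream blocks.
// With offloading enabled, block weights are loaded lazily and released after use
// (the first joint block stays resident).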
FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device) : dtype(dtype), offload(offload) {
    for (int i = 0; i < 19; i++) {
        transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
        registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
        if (offload && i > 0) { // don't offload first block
            transformer_blocks.back()->setLazyLoad(true);
            transformer_blocks.back()->releaseLazyParams();
        }
    }
    for (int i = 0; i < 38; i++) {
        single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, Device::cuda()));
        registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
        if (offload) {
            single_transformer_blocks.back()->setLazyLoad(true);
            single_transformer_blocks.back()->releaseLazyParams();
        }
    }
}

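// Full forward pass: temb conditions every block; separate rotary embeddings are used for
// the image tokens, the text tokens, and the fused single-stream stage.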
Tensor FluxModel::forward(
        Tensor hidden_states,
        Tensor encoder_hidden_states,
        Tensor temb,
        Tensor rotary_emb_img,
        Tensor rotary_emb_context,
        Tensor rotary_emb_single,
        Tensor controlnet_block_samples,
        Tensor controlnet_single_block_samples,
        bool skip_first_layer) {
    const int batch_size = hidden_states.shape[0];
    const Tensor::ScalarType dtype = hidden_states.dtype();
    const Device device = hidden_states.device();

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();

    Tensor concat;

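    // Per-layer worker: joint blocks run first on (hidden_states, encoder_hidden_states);
    // at the boundary the two streams are concatenated (txt first, matching diffusers) and
    // the single-stream blocks take over. ControlNet residuals, if present, are added after
    // the corresponding block.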
    auto compute = [&](int layer) {
        if (skip_first_layer && size_t(layer) == 0) return;
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
            if (controlnet_block_samples.valid()) {
                const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

                int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
                int block_index = layer / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_block_samples;

                hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
            }
        } else {
            if (size_t(layer) == transformer_blocks.size()) {
                // txt first, same as diffusers
                concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
                for (int i = 0; i < batch_size; i++) {
                    concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states);
                    concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states);
                }
                hidden_states = concat;
                encoder_hidden_states = {};
            }

            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
            if (controlnet_single_block_samples.valid()) {
                const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

                int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
                int block_index = (layer - transformer_blocks.size()) / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_single_block_samples

                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
muyangli's avatar
muyangli committed
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
        }
    };
    auto load = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->loadLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->loadLazyParams();
        }
    };
    auto unload = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->releaseLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->releaseLazyParams();
        }
    };
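    // When offloading is enabled, the helper interleaves per-layer compute with lazy
    // loading and releasing of each block's weights.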

    LayerOffloadHelper helper(this->offload, numLayers, compute, load, unload);
    helper.run();

    return hidden_states;
}

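// Runs one transformer block by index and applies the matching ControlNet residual,
// if one is supplied.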
std::tuple<Tensor, Tensor> FluxModel::forward_layer(
        size_t layer,
        Tensor hidden_states,
        Tensor encoder_hidden_states,
        Tensor temb,
        Tensor rotary_emb_img,
        Tensor rotary_emb_context,
        Tensor controlnet_block_samples,
        Tensor controlnet_single_block_samples) {

    std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
        hidden_states,
        encoder_hidden_states,
        temb,
        rotary_emb_img,
        rotary_emb_context, 0.0f);

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
        const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
        int block_index = layer / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_block_samples;

        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
    } else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
        const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

        int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
        int block_index = (layer - transformer_blocks.size()) / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_single_block_samples

        auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
        slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
        hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
    }

    return { hidden_states, encoder_hidden_states };
}

void FluxModel::setAttentionImpl(AttentionImpl impl) {
    for (auto &&block : this->transformer_blocks) {
        block->attnImpl = impl;
    }
    for (auto &&block : this->single_transformer_blocks) {
        block->attnImpl = impl;
    }
}