FluxModel.cpp 36.8 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
#include "FluxModel.h"
#include "kernels/misc_kernels.h"
#include "kernels/gemm_batched.h"
4
#include "kernels/zgemm/zgemm.h"
Zhekai Zhang's avatar
Zhekai Zhang committed
5
#include "flash_api.h"
Zhekai Zhang's avatar
Zhekai Zhang committed
6
7
8
#include "activation.h"
#include <nvtx3/nvToolsExt.h>

K's avatar
K committed
9
10
#include <pybind11/functional.h>        
           
Zhekai Zhang's avatar
Zhekai Zhang committed
11
12
13
#include <iostream>

using spdlog::fmt_lib::format;
muyangli's avatar
muyangli committed
14
using namespace nunchaku;
Zhekai Zhang's avatar
Zhekai Zhang committed
15
16
17
18



// Fused W4A4 MLP: fc1 with GELU+quantization fused into the kernel, whose
// quantized activation feeds fc2's quantized forward pass directly.
Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
    auto quant_act = fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2);
    return fc2.forward_quant(std::get<GEMM_W4A4::QuantizedActivation>(quant_act));
}

// Tensor forward_mlp(GEMM_W8A8 &fc1, GEMM_W8A8 &fc2, Tensor norm_hidden_states) {
//     Tensor ff_output = fc2.forward(fc1.forward(norm_hidden_states), GEMM_W8A8::FuseOptions::GELU);
//     return ff_output;
// }


// Thin wrapper: run a single W4A4 linear layer on x.
Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
    Tensor out = fc.forward(x);
    // (previously: std::get<Tensor>(fc.forward(x)))
    return out;
}

// Tensor forward_fc(GEMM_W8A8 &fc, Tensor x) {
//     return fc.forward(x);
// }


// AdaLayerNormZeroSingle: adaptive LayerNorm whose shift/scale/gate modulation
// terms (3 * dim total) are predicted from a conditioning embedding by one
// linear projection.
AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device) :
    dim(dim),
    linear(dim, 3 * dim, true, dtype, device),  // emb -> [shift, scale, gate]
    norm(dim, 1e-6, false, dtype, device)       // LayerNorm, no learned affine
{
    registerChildren
        (linear, "linear")
        (norm, "norm")
    ;
}

// Apply adaptive LayerNorm: project SiLU(emb) into shift/scale/gate, normalize
// x, then apply the scale/shift modulation in one batched fused kernel.
// Returns the modulated activations plus the gate for the residual path.
AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor emb) {
    debug("emb_input", emb);
    emb = linear.forward(Silu::forward(emb));
    debug("emb_linear", emb);

    auto &&[shift, scale, gate] = kernels::split_mod<3>(emb);
    debug("scale_msa", scale);
    debug("shift_msa", shift);

    debug("x", x);
    Tensor normed = norm.forward(x);
    debug("norm_x", normed);

    // Fused batched multiply-add (replaces kernels::mul_add).
    kernels::mul_add_batch(normed, scale, true, 0.0, shift, true);

    return Output{normed, gate};
}

Hyunsung Lee's avatar
Hyunsung Lee committed
69
// AdaLayerNormZero: adaptive LayerNorm for the joint transformer block.
// When pre_only, only shift/scale (2 * dim) are predicted; otherwise the full
// six modulation terms (shift/scale/gate for attention and for the MLP).
AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device) :
    dim(dim), pre_only(pre_only),
    linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
    norm(dim, 1e-6, false, dtype, device)  // LayerNorm, no learned affine
{
    registerChildren
        (linear, "linear")
        (norm, "norm")
    ;
}

// Apply adaptive LayerNorm: project SiLU(emb) into the modulation terms, then
// normalize and modulate x. In pre_only mode only shift/scale exist and just
// the modulated tensor is returned; otherwise gate and MLP terms come too.
AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
    debug("x", x);

    debug("emb_input", emb);
    emb = linear.forward(Silu::forward(emb));
    debug("emb_linear", emb);

    if (pre_only) {
        auto &&[shift, scale] = kernels::split_mod<2>(emb);
        debug("shift_msa", shift);

        Tensor normed = norm.forward(x);
        debug("norm_x", normed);

        // Fused batched multiply-add (replaces kernels::mul_add).
        kernels::mul_add_batch(normed, scale, true, 0.0, shift, true);
        debug("norm_x_scaled", normed);

        return Output{normed};
    }

    auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(emb);
    debug("shift_msa", shift_msa);

    Tensor normed = norm.forward(x);
    debug("norm_x", normed);

    // Fused batched multiply-add (replaces kernels::mul_add).
    kernels::mul_add_batch(normed, scale_msa, true, 0.0, shift_msa, true);
    debug("norm_x_scaled", normed);

    return Output{normed, gate_msa, shift_mlp, scale_mlp, gate_mlp};
}


Hyunsung Lee's avatar
Hyunsung Lee committed
115
// Attention module. Builds a per-head mask-type tensor (values 1..num_heads)
// on the host, then moves it to the target device for the block-sparse kernel.
Attention::Attention(int num_heads, int dim_head, Device device) :
    num_heads(num_heads), dim_head(dim_head), force_fp16(false)
{
    headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
    int32_t *mask = headmask_type.data_ptr<int32_t>();
    for (int h = 0; h < num_heads; ++h) {
        mask[h] = h + 1;
    }
    headmask_type = headmask_type.copy(device);
}

125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
// Dense multi-head attention over a fused QKV tensor.
//
// qkv: [batch, tokens, num_heads * dim_head * 3], last dim laid out as
//      [Q heads | K heads | V heads].
// Returns [batch * tokens, num_heads, dim_head].
//
// Fix: removed the unused local `const Device device = qkv.device();`
// (this overload never referenced it).
Tensor Attention::forward(Tensor qkv) {
    assert(qkv.ndims() == 3);

    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    // Split the fused projection into per-head Q/K/V views (no copy).
    Tensor reshaped = qkv.view({batch_size, num_tokens, num_heads * 3, dim_head});
    Tensor q = reshaped.slice(2, 0, num_heads);
    Tensor k = reshaped.slice(2, num_heads, num_heads * 2);
    Tensor v = reshaped.slice(2, num_heads * 2, num_heads * 3);

    Tensor raw_attn_output = mha_fwd(q, k, v,
        0.0f,                       // dropout
        pow(q.shape[-1], (-0.5)),   // softmax scale = 1/sqrt(dim_head)
        false, -1, -1, false
    ).front();

    assert(raw_attn_output.shape[0] == batch_size);
    assert(raw_attn_output.shape[1] == num_tokens);
    assert(raw_attn_output.shape[2] == num_heads);
    assert(raw_attn_output.shape[3] == dim_head);

    return raw_attn_output.view({batch_size * num_tokens, num_heads, dim_head});
}

Zhekai Zhang's avatar
Zhekai Zhang committed
152
// Block-sparse multi-head attention over a fused QKV tensor.
//
// qkv:           [batch, tokens, num_heads * dim_head * 3]
// pool_qkv:      optional pooled QKV used to score POOL_SIZE x POOL_SIZE
//                attention blocks for sparsity selection (may be invalid).
// sparsityRatio: fraction of blocks to drop (0 = effectively dense).
// Returns [batch * tokens, num_heads, dim_head].
Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
    // Optionally run the kernel in FP16 even when activations are another dtype.
    const bool cast_fp16 = this->force_fp16 && qkv.scalar_type() != Tensor::FP16;

    assert(qkv.ndims() == 3);

    const Device device = qkv.device();
    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    constexpr int POOL_SIZE = 128;  // attention block granularity for the mask
    const int pool_tokens = ceilDiv(num_tokens, POOL_SIZE);

    Tensor blockmask;

    if (pool_qkv.valid()) {
        assert(pool_qkv.shape[0] == batch_size);
        assert(pool_qkv.shape[1] == pool_tokens);
        assert(pool_qkv.shape[2] == num_heads * dim_head * 3);
    }

    // Block-level importance scores: pooled-Q x pooled-K^T per head.
    Tensor pool_score = Tensor::allocate({batch_size, num_heads, pool_tokens, pool_tokens}, Tensor::FP32, device);

    if (pool_qkv.valid() && sparsityRatio > 0) {
        pool_qkv = pool_qkv.view({batch_size, pool_tokens, 3, num_heads, dim_head});
        pool_qkv = pool_qkv.transpose(1, 2).transpose(2, 3);    // [batch_size, 3, num_heads, poolTokens, dim_head]
        for (int i = 0; i < batch_size; i++) {
            Tensor pool_q = pool_qkv.slice(0, i, i+1).slice(1, 0, 1);
            Tensor pool_k = pool_qkv.slice(0, i, i+1).slice(1, 1, 2);
            Tensor pool_s = pool_score.slice(0, i, i+1);
            gemm_batched_fp16(pool_q, pool_k, pool_s);
        }
    }

    // Keep the top (1 - sparsityRatio) fraction of blocks per row.
    // NOTE(review): pool_score is consumed here even when the branch above did
    // not fill it (invalid pool_qkv or sparsityRatio == 0) -- presumably topk
    // then keeps all blocks; confirm kernels::topk semantics on uninitialized
    // scores when pool_tokens * (1 - sparsityRatio) == pool_tokens.
    blockmask = kernels::topk(pool_score, pool_tokens * (1 - sparsityRatio));

    // Reuse the cached cumulative-sequence-length tensor if it still matches
    // this (batch_size, num_tokens) configuration; otherwise invalidate it.
    if (cu_seqlens_cpu.valid()) {
        if (cu_seqlens_cpu.shape[0] != batch_size + 1) {
            cu_seqlens_cpu = Tensor{};
        } else {
            for (int i = 0; i <= batch_size; i++) {
                if (cu_seqlens_cpu.data_ptr<int32_t>()[i] != num_tokens * i) {
                    cu_seqlens_cpu = Tensor{};
                    break;
                }
            }
        }
    }
    // Rebuild on the host: cu_seqlens[i] = i * num_tokens (uniform lengths).
    if (!cu_seqlens_cpu.valid()) {
        cu_seqlens_cpu = Tensor::allocate({batch_size + 1}, Tensor::INT32, Device::cpu());
        cu_seqlens_cpu.data_ptr<int32_t>()[0] = 0;
        for (int i = 1; i <= batch_size; i++) {
            cu_seqlens_cpu.data_ptr<int32_t>()[i] = cu_seqlens_cpu.data_ptr<int32_t>()[i - 1] + num_tokens;
        }
    }

    // Downcast the inputs to FP16 if forced.
    if (cast_fp16) {
        Tensor tmp = Tensor::empty(qkv.shape.dataExtent, Tensor::FP16, qkv.device());
        kernels::cast(qkv, tmp);
        qkv = tmp;
    }

    debug("qkv", qkv);

    Tensor cu_seqlens = cu_seqlens_cpu.copy(device);

    // Flatten batch/tokens and split the fused projection into Q/K/V views.
    Tensor reshaped = qkv.view({batch_size * num_tokens, num_heads * 3, dim_head});
    Tensor q = reshaped.slice(1, 0, num_heads);
    Tensor k = reshaped.slice(1, num_heads, num_heads * 2);
    Tensor v = reshaped.slice(1, num_heads * 2, num_heads * 3);

    spdlog::debug("q,k,v={}", q.shape.str());

    // Block-sparse attention kernel; blockmask selects which tiles run.
    Tensor raw_attn_output = mha_fwd_block(
        q, k, v,
        cu_seqlens, cu_seqlens,
        POOL_SIZE, POOL_SIZE,
        headmask_type,
        {},
        blockmask,
        num_tokens,
        num_tokens,
        0.0f,                       // dropout
        pow(q.shape[-1], (-0.5)),   // softmax scale = 1/sqrt(dim_head)
        false, false, false, -1, -1
    ).front();

    debug("raw_attn_output", raw_attn_output);

    // Cast back (to BF16) if we downcast the inputs above.
    if (cast_fp16) {
        Tensor tmp = Tensor::empty(raw_attn_output.shape.dataExtent, Tensor::BF16, raw_attn_output.device());
        kernels::cast(raw_attn_output, tmp);
        raw_attn_output = tmp;
    }

    // Older dense / varlen code paths kept for reference:
    /**
    Tensor raw_attn_output = mha_varlen_fwd(q, k, v,
        cu_seqlens,
        cu_seqlens,
        concat.shape[1],
        concat.shape[1],
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false,
        true,
        -1, -1,
        false
    ).front();

    Tensor raw_attn_output = mha_fwd(q, k, v,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, -1, -1, false
    ).front();

    Tensor raw_attn_output = mha_varlen_fwd(
        q, k, v,
        cu_seqlens, cu_seqlens,
        num_tokens_img + num_tokens_txt, num_tokens_img + num_tokens_txt,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, false, -1, -1, false
    ).front();
    **/

    assert(raw_attn_output.shape[0] == batch_size * num_tokens);
    assert(raw_attn_output.shape[1] == num_heads);
    assert(raw_attn_output.shape[2] == dim_head);

    return raw_attn_output;
}

284
285
286
287
288
289
290
291
292
293
// Toggle forced-FP16 attention on every Attention instance reachable from
// `module` (recursive traversal of the module tree).
void Attention::setForceFP16(Module *module, bool value) {
    spdlog::info("{} force fp16 attention", value ? "Enable" : "Disable");

    module->traverse([&](Module *m) {
        Attention *attn = dynamic_cast<Attention *>(m);
        if (attn != nullptr) {
            attn->force_fp16 = value;
        }
    });
}

294
// Single-stream Flux transformer block: adaptive norm, fused QKV attention,
// and an MLP branch sharing the same normalized input.
FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    dim(dim),
    dim_head(attention_head_dim / num_attention_heads),
    num_heads(num_attention_heads),
    mlp_hidden_dim(dim * mlp_ratio),
    norm(dim, dtype, device),                               // AdaLayerNormZeroSingle
    mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
    mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device),
    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),   // fused Q/K/V projection
    norm_q(dim_head, 1e-6, false, dtype, device),           // per-head QK-norm
    norm_k(dim_head, 1e-6, false, dtype, device),
    attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
    out_proj(dim, dim, true, use_fp4, dtype, device)
{
    registerChildren
        (norm, "norm")
        (mlp_fc1, "mlp_fc1")
        (mlp_fc2, "mlp_fc2")
        (qkv_proj, "qkv_proj")
        (norm_q, "norm_q")
        (norm_k, "norm_k")
        (attn, "attn")
        (out_proj, "out_proj")
    ;
}

// Single-stream block forward pass.
//
// hidden_states: [batch, tokens, dim]; temb: conditioning embedding;
// rotary_emb: rotary position embedding consumed inside qkv_proj.
// Attention and MLP run in parallel on the same normalized input; their sum
// is gated and added back onto the residual.
Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {

    nvtxRangePushA("FluxSingleTransformerBlock");

    const int batch_size = hidden_states.shape[0];
    const int num_tokens = hidden_states.shape[1];

    // Adaptive LayerNorm conditioned on temb; also yields the residual gate.
    auto &&[norm_hidden_states, gate] = this->norm.forward(hidden_states, temb);
    debug("norm_hidden_states", norm_hidden_states);
    debug("gate", gate);

    Tensor residual = hidden_states;

    Tensor attn_output;

    debug("rotary_emb", rotary_emb);

    if (attnImpl == AttentionImpl::FlashAttention2) {
        Tensor qkv = Tensor::allocate({batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
        // qkv_proj.forward(norm_hidden_states, qkv, {});
        // debug("qkv_raw", qkv);

        // Per-sample QKV projection with fused QK-norm and rotary embedding.
        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(norm_hidden_states.slice(0, i, i+1), qkv.slice(0, i, i+1), {}, norm_q.weight, norm_k.weight, rotary_emb);
        }
        debug("qkv", qkv);
        // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);

        // attn_output = attn.forward(qkv, {}, 0);
        attn_output = attn.forward(qkv);
        attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        // assert(batch_size == 1);

        // The FP16 kernel expects token counts padded to a multiple of 256.
        const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;

        Tensor q = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor k = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor v = Tensor::allocate({batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());

        // Projection writes Q/K/V directly into the padded [b, h, t, d] layout.
        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(
                norm_hidden_states.slice(0, i, i+1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb, 
                q.slice(0, i, i+1), 
                k.slice(0, i, i+1), 
                v.slice(0, i, i+1), 
                num_tokens);
        }

        debug("packed_q", q);
        debug("packed_k", k);
        debug("packed_v", v);

        Tensor o = Tensor::allocate({batch_size, num_tokens_pad, num_heads * dim_head}, norm_hidden_states.scalar_type(), norm_hidden_states.device());

        kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));

        // Strip the token padding. A plain slice works when there is a single
        // sample (or no padding); otherwise copy each sample's unpadded rows
        // out with one strided 2-D async memcpy.
        if (batch_size == 1 || num_tokens_pad == num_tokens) {
            attn_output = o.slice(1, 0, num_tokens);
        } else {
            attn_output = Tensor::allocate({batch_size, num_tokens, num_heads * dim_head}, o.scalar_type(), o.device());
            checkCUDA(cudaMemcpy2DAsync(
                attn_output.data_ptr(),
                attn_output.stride(0) * attn_output.scalar_size(),  // dst pitch (bytes)
                o.data_ptr(),
                o.stride(0) * o.scalar_size(),                      // src pitch (bytes)
                attn_output.stride(0) * attn_output.scalar_size(),  // row width (bytes)
                batch_size,                                         // one row per sample
                cudaMemcpyDeviceToDevice,
                getCurrentCUDAStream()
            ));
        }
    } else {
        assert(false);
    }

    debug("raw_attn_output", attn_output);

    attn_output = forward_fc(out_proj, attn_output);
    debug("attn_output", attn_output);

    // MLP branch runs on the same normalized input (parallel to attention).
    Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
    debug("ff_output", ff_output);

    hidden_states = kernels::add(attn_output, ff_output);
    debug("attn_ff_output", hidden_states);

    // Gate the combined output and add the residual in one fused kernel.
    // kernels::mul_add(hidden_states, gate, residual);
    kernels::mul_add_batch(hidden_states, gate, true, 0.0, residual, true);

    nvtxRangePop();

    return hidden_states;
}

Hyunsung Lee's avatar
Hyunsung Lee committed
417
// Dual-stream (MMDiT-style) transformer block: separate projections and MLPs
// for the image stream and the text/context stream, with joint attention.
// When context_pre_only, the context stream stops after its pre-norm.
JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, bool context_pre_only, bool use_fp4, Tensor::ScalarType dtype, Device device) :
    dim(dim),
    dim_head(attention_head_dim / num_attention_heads),
    num_heads(num_attention_heads),
    context_pre_only(context_pre_only),
    norm1(dim, false, dtype, device),                       // image-stream AdaLayerNormZero
    norm1_context(dim, context_pre_only, dtype, device),    // context-stream AdaLayerNormZero
    qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
    qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device),
    norm_q(dim_head, 1e-6, false, dtype, device),           // per-head QK-norms
    norm_k(dim_head, 1e-6, false, dtype, device),
    norm_added_q(dim_head, 1e-6, false, dtype, device),     // context-stream QK-norms
    norm_added_k(dim_head, 1e-6, false, dtype, device),
    attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
    out_proj(dim, dim, true, use_fp4, dtype, device),
    out_proj_context(dim, dim, true, use_fp4, dtype, device),
    norm2(dim, 1e-6, false, dtype, device),
    norm2_context(dim, 1e-6, false, dtype, device),
    mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device),    // MLP ratio fixed at 4
    mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
    mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
    mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device)
{
    registerChildren
        (norm1, "norm1")
        (norm1_context, "norm1_context")
        (qkv_proj, "qkv_proj")
        (qkv_proj_context, "qkv_proj_context")
        (norm_q, "norm_q")
        (norm_k, "norm_k")
        (norm_added_q, "norm_added_q")
        (norm_added_k, "norm_added_k")
        (attn, "attn")
        (out_proj, "out_proj")
        (out_proj_context, "out_proj_context")
        (norm2, "norm2")
        (norm2_context, "norm2_context")
        (mlp_fc1, "mlp_fc1")
        (mlp_fc2, "mlp_fc2")
        (mlp_context_fc1, "mlp_context_fc1")
        (mlp_context_fc2, "mlp_context_fc2")
    ;
}


// hidden_states: [Batch, Width * Height, dim]
// encoder_hidden_states: [Batch, Token, dim]
std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor temb, Tensor rotary_emb, Tensor rotary_emb_context, float sparsityRatio) {
    int batch_size = hidden_states.shape[0];
    assert(encoder_hidden_states.shape[0] == batch_size);

    nvtxRangePushA("JointTransformerBlock");

    nvtxRangePushA("AdaNorm");


    int num_tokens_img = hidden_states.shape[1];
474
    int num_tokens_txt = encoder_hidden_states.shape[1];
Hyunsung Lee's avatar
Hyunsung Lee committed
475

Zhekai Zhang's avatar
Zhekai Zhang committed
476
477
478
479
    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);

    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}", hidden_states.shape.str(), encoder_hidden_states.shape.str(), temb.shape.str());
480
    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);
Zhekai Zhang's avatar
Zhekai Zhang committed
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497

    auto norm1_output = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

#if 0
    norm1_output.x = hidden_states;
    norm1_context_output.x = encoder_hidden_states;
#endif

    debug("norm_hidden_states", norm1_output.x);
    debug("norm_encoder_hidden_states", norm1_context_output.x);

    constexpr int POOL_SIZE = Attention::POOL_SIZE;

    nvtxRangePop();

    auto stream = getCurrentCUDAStream();
Hyunsung Lee's avatar
Hyunsung Lee committed
498

499
500
    int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
    Tensor raw_attn_output;
Zhekai Zhang's avatar
Zhekai Zhang committed
501

502
503
504
    if (attnImpl == AttentionImpl::FlashAttention2) {
        num_tokens_img_pad = num_tokens_img;
        num_tokens_txt_pad = num_tokens_txt;
505

506
507
        Tensor concat;
        Tensor pool;
Hyunsung Lee's avatar
Hyunsung Lee committed
508

509
510
        {
            nvtxRangePushA("qkv_proj");
Hyunsung Lee's avatar
Hyunsung Lee committed
511

512
            const bool blockSparse = sparsityRatio > 0;
Hyunsung Lee's avatar
Hyunsung Lee committed
513

514
515
            const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
            concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());
Hyunsung Lee's avatar
Hyunsung Lee committed
516

517
518
            pool = blockSparse
                ? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
Zhekai Zhang's avatar
Zhekai Zhang committed
519
                : Tensor{};
Hyunsung Lee's avatar
Hyunsung Lee committed
520

521
522
523
524
            for (int i = 0; i < batch_size; i++) {
                // img first
                Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
                Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
Hyunsung Lee's avatar
Hyunsung Lee committed
525

526
527
528
529
                Tensor pool_qkv = pool.valid()
                    ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE)
                    : Tensor{};
                Tensor pool_qkv_context = pool.valid()
530
                    ? pool.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
531
                    : Tensor{};
Hyunsung Lee's avatar
Hyunsung Lee committed
532

533
534
                // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
                // debug("qkv_raw", qkv);
Hyunsung Lee's avatar
Hyunsung Lee committed
535

536
                debug("rotary_emb", rotary_emb);
Hyunsung Lee's avatar
Hyunsung Lee committed
537

538
539
                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
                debug("qkv", qkv);
Hyunsung Lee's avatar
Hyunsung Lee committed
540

541
542
                // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
                // debug("qkv_context_raw", qkv_context);
Hyunsung Lee's avatar
Hyunsung Lee committed
543

544
                debug("rotary_emb_context", rotary_emb_context);
Hyunsung Lee's avatar
Hyunsung Lee committed
545

546
547
548
                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context, pool_qkv_context, norm_added_q.weight, norm_added_k.weight, rotary_emb_context);
                debug("qkv_context", qkv_context);
            }
Hyunsung Lee's avatar
Hyunsung Lee committed
549

550
551
            nvtxRangePop();
        }
Hyunsung Lee's avatar
Hyunsung Lee committed
552

553
554
        spdlog::debug("concat={}", concat.shape.str());
        debug("concat", concat);
Hyunsung Lee's avatar
Hyunsung Lee committed
555

556
        assert(concat.shape[2] == num_heads * dim_head * 3);
Hyunsung Lee's avatar
Hyunsung Lee committed
557

558
        nvtxRangePushA("Attention");
Hyunsung Lee's avatar
Hyunsung Lee committed
559

560
561
562
563
564
        if (pool.valid()) {
            raw_attn_output = attn.forward(concat, pool, sparsityRatio);
        } else {
            raw_attn_output = attn.forward(concat);
        }
Hyunsung Lee's avatar
Hyunsung Lee committed
565

566
        nvtxRangePop();
Hyunsung Lee's avatar
Hyunsung Lee committed
567

568
        spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
Hyunsung Lee's avatar
Hyunsung Lee committed
569

570
        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});
Hyunsung Lee's avatar
Hyunsung Lee committed
571

572
573
574
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
        num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;
Zhekai Zhang's avatar
Zhekai Zhang committed
575

576
        Tensor concat_q, concat_k, concat_v;
Zhekai Zhang's avatar
Zhekai Zhang committed
577

578
579
        {
            nvtxRangePushA("qkv_proj");
Hyunsung Lee's avatar
Hyunsung Lee committed
580

581
582
583
            concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head}, Tensor::FP16, norm1_output.x.device());
            concat_k = Tensor::empty_like(concat_q);
            concat_v = Tensor::empty_like(concat_q);
Hyunsung Lee's avatar
Hyunsung Lee committed
584

585
586
587
588
589
590
591
592
            for (int i = 0; i < batch_size; i++) {
                // img first
                auto sliceImg = [&](Tensor x) {
                    return x.slice(0, i, i+1).slice(2, 0, num_tokens_img_pad);
                };
                auto sliceTxt = [&](Tensor x) {
                    return x.slice(0, i, i+1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
                };
Hyunsung Lee's avatar
Hyunsung Lee committed
593

594
595
596
597
                qkv_proj.forward(
                    norm1_output.x.slice(0, i, i + 1), {}, {}, norm_q.weight, norm_k.weight, rotary_emb,
                    sliceImg(concat_q), sliceImg(concat_k), sliceImg(concat_v), num_tokens_img
                );
Hyunsung Lee's avatar
Hyunsung Lee committed
598

599
600
601
602
603
                qkv_proj_context.forward(
                    norm1_context_output.x.slice(0, i, i + 1), {}, {}, norm_added_q.weight, norm_added_k.weight, rotary_emb_context,
                    sliceTxt(concat_q), sliceTxt(concat_k), sliceTxt(concat_v), num_tokens_txt
                );
            }
Zhekai Zhang's avatar
Zhekai Zhang committed
604

605
606
607
            debug("concat_q", concat_q);
            debug("concat_k", concat_k);
            debug("concat_v", concat_v);
Hyunsung Lee's avatar
Hyunsung Lee committed
608

609
            nvtxRangePop();
Zhekai Zhang's avatar
Zhekai Zhang committed
610
611
        }

612
        raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head}, norm1_output.x.scalar_type(), norm1_output.x.device());
Zhekai Zhang's avatar
Zhekai Zhang committed
613

614
        nvtxRangePushA("Attention");
Zhekai Zhang's avatar
Zhekai Zhang committed
615

616
        kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));
Zhekai Zhang's avatar
Zhekai Zhang committed
617

618
        nvtxRangePop();
Zhekai Zhang's avatar
Zhekai Zhang committed
619

620
621
622
623
        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
    } else {
        assert(false);
    }
Zhekai Zhang's avatar
Zhekai Zhang committed
624
625
626
627
628
629
630
631

    debug("raw_attn_output", raw_attn_output);

    {
        nvtxRangePushA("o_proj");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;

632
        // raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]
Zhekai Zhang's avatar
Zhekai Zhang committed
633
634
635
636
637
638
639

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split = raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(
muyangli's avatar
muyangli committed
640
                raw_attn_output_split.data_ptr(),
Zhekai Zhang's avatar
Zhekai Zhang committed
641
642
                num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                raw_attn_output.data_ptr(),
643
                (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
Zhekai Zhang's avatar
Zhekai Zhang committed
644
645
                num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                batch_size,
muyangli's avatar
muyangli committed
646
                cudaMemcpyDeviceToDevice,
Zhekai Zhang's avatar
Zhekai Zhang committed
647
648
                stream));
        }
muyangli's avatar
muyangli committed
649

Zhekai Zhang's avatar
Zhekai Zhang committed
650
651
652
653
654
655
656
657

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("img.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output = forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
        debug("img.attn_output", attn_output);

#if 1
658
659
        // kernels::mul_add(attn_output, gate_msa, hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
660
661
662
663
664
665
666
667
668
669
        hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", hidden_states.shape.str());

        Tensor norm_hidden_states = norm2.forward(hidden_states);
        debug("scale_mlp", scale_mlp);
        debug("shift_mlp", shift_mlp);
670
671
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
672
673
674
675
676
677
678
679
680
681
682
683

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        Tensor norm_hidden_states = hidden_states;
#endif

        // Tensor ff_output = mlp_fc2.forward(GELU::forward(mlp_fc1.forward(norm_hidden_states)));
        debug("img.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
        debug("img.ff_output", ff_output);

        debug("gate_mlp", gate_mlp);
684
685
        // kernels::mul_add(ff_output, gate_mlp, hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
        hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", hidden_states.shape.str());
    }

    if (context_pre_only) {
        return { hidden_states, encoder_hidden_states };
    }

    {
        nvtxRangePushA("o_proj_context");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_context_output;

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
704
            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt).reshape({batch_size, num_tokens_txt, num_heads * dim_head});
Zhekai Zhang's avatar
Zhekai Zhang committed
705
        } else {
706
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head}, raw_attn_output.scalar_type(), raw_attn_output.device());
Zhekai Zhang's avatar
Zhekai Zhang committed
707
            checkCUDA(cudaMemcpy2DAsync(
muyangli's avatar
muyangli committed
708
                raw_attn_output_split.data_ptr(),
709
710
711
712
                num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head * raw_attn_output.scalar_size(),
                num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
Zhekai Zhang's avatar
Zhekai Zhang committed
713
                batch_size,
muyangli's avatar
muyangli committed
714
                cudaMemcpyDeviceToDevice,
Zhekai Zhang's avatar
Zhekai Zhang committed
715
716
                stream));
        }
muyangli's avatar
muyangli committed
717

Zhekai Zhang's avatar
Zhekai Zhang committed
718
719
720
721
722
723
724
725

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("context.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output = forward_fc(out_proj_context, raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
        debug("context.attn_output", attn_output);

#if 1
726
727
        // kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
728
729
730
731
732
733
734
735
736
737
        encoder_hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", encoder_hidden_states.shape.str());

        Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
        debug("c_scale_mlp", scale_mlp);
        debug("c_shift_mlp", shift_mlp);
738
739
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
740
741
742
743
744

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        auto norm_hidden_states = encoder_hidden_states;
#endif
muyangli's avatar
muyangli committed
745

Zhekai Zhang's avatar
Zhekai Zhang committed
746
747
748
749
750
751
752
753

        // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
        // Tensor ff_output = mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
        debug("context.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
        debug("context.ff_output", ff_output);

        debug("c_gate_mlp", gate_mlp);
754
755
        // kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
756
757
758
759
760
761
762
763
764
765
766
767
        encoder_hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", encoder_hidden_states.shape.str());
    }

    nvtxRangePop();

    return { hidden_states, encoder_hidden_states };
}

768
/// Builds the Flux.1 transformer stack: 19 joint (dual-stream) blocks followed
/// by 38 single-stream blocks, all with hidden dim 3072 and 24 attention heads.
///
/// @param use_fp4  use FP4 quantized weights in every block
/// @param offload  enable lazy (CPU-offloaded) parameter loading per block
/// @param dtype    activation scalar type for all blocks
/// @param device   device the blocks' weights are created on
FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device) : dtype(dtype), offload(offload) {
    // Flux.1 architecture constants (previously inline magic numbers).
    constexpr int kNumJointBlocks  = 19;
    constexpr int kNumSingleBlocks = 38;
    constexpr int kHiddenDim       = 3072;
    constexpr int kNumHeads        = 24;
    constexpr int kMlpRatio        = 4;   // single-block MLP expansion factor

    for (int i = 0; i < kNumJointBlocks; i++) {
        transformer_blocks.push_back(std::make_unique<JointTransformerBlock>(kHiddenDim, kNumHeads, kHiddenDim, false, use_fp4, dtype, device));
        registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
        // When offloading, keep the first joint block resident so the first
        // forward pass does not stall on a parameter load.
        if (offload && i > 0) { // don't offload first block
            transformer_blocks.back()->setLazyLoad(true);
            transformer_blocks.back()->releaseLazyParams();
        }
    }
    for (int i = 0; i < kNumSingleBlocks; i++) {
        single_transformer_blocks.push_back(std::make_unique<FluxSingleTransformerBlock>(kHiddenDim, kNumHeads, kHiddenDim, kMlpRatio, use_fp4, dtype, device));
        registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
        // All single blocks may be offloaded; the joint stack runs first, so
        // their parameters can be prefetched while earlier layers compute.
        if (offload) {
            single_transformer_blocks.back()->setLazyLoad(true);
            single_transformer_blocks.back()->releaseLazyParams();
        }
    }
}

Hyunsung Lee's avatar
Hyunsung Lee committed
787
788
789
790
791
792
793
794
795
796
/// Runs the full Flux stack: joint blocks on (img, txt) streams, then the
/// streams are concatenated (txt first, matching diffusers) and pushed through
/// the single-stream blocks. Supports ControlNet residuals and an optional
/// Python residual callback, with layer-by-layer CPU offloading when enabled.
///
/// @param hidden_states            image-stream activations [B, img_tokens, 3072]
/// @param encoder_hidden_states    text-stream activations  [B, txt_tokens, 3072]
/// @param temb                     time/conditioning embedding
/// @param rotary_emb_img/context   RoPE tables for the joint blocks
/// @param rotary_emb_single        RoPE table for the single blocks
/// @param controlnet_block_samples / controlnet_single_block_samples
///                                 optional ControlNet residual stacks (may be invalid)
/// @param skip_first_layer         skip layer 0 (used for caching schemes)
/// @return final hidden states [B, txt_tokens + img_tokens, 3072]
Tensor FluxModel::forward(
        Tensor hidden_states,
        Tensor encoder_hidden_states,
        Tensor temb,
        Tensor rotary_emb_img,
        Tensor rotary_emb_context,
        Tensor rotary_emb_single,
        Tensor controlnet_block_samples,
        Tensor controlnet_single_block_samples,
        bool skip_first_layer) {
    const int batch_size = hidden_states.shape[0];
    const Tensor::ScalarType dtype = hidden_states.dtype();
    const Device device = hidden_states.device();

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();

    // Backing storage for the txt+img concatenation; kept alive across layers.
    Tensor concat;

    auto compute = [&](int layer) {
        if (skip_first_layer && size_t(layer) == 0) return;
        if (size_t(layer) < transformer_blocks.size()) {
            // ---- joint (dual-stream) block ----
            auto &block = transformer_blocks.at(layer);
            std::tie(hidden_states, encoder_hidden_states) = block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
            if (controlnet_block_samples.valid()) {
                const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
                // Spread the ControlNet residuals evenly over the joint blocks.
                int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
                int block_index = layer / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_block_samples;

                hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
            }
            // Optional Python-side residual hook on every other joint block.
            if (residual_callback && layer % 2 == 0) {
                Tensor cpu_input = hidden_states.copy(Device::cpu());
                pybind11::gil_scoped_acquire gil;  // callback runs Python code
                Tensor cpu_output = residual_callback(cpu_input);
                // Fix: copy back to the tensors' own device instead of the
                // hard-coded default CUDA device (wrong GPU in multi-GPU setups).
                Tensor residual = cpu_output.copy(device);
                hidden_states = kernels::add(hidden_states, residual);
            }
        } else {
            if (size_t(layer) == transformer_blocks.size()) {
                // Entering the single-stream stack: fuse the two streams.
                // txt first, same as diffusers
                concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
                for (int i = 0; i < batch_size; i++) {
                    concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states.slice(0, i, i + 1));
                    concat.slice(0, i, i + 1).slice(1, txt_tokens, txt_tokens + img_tokens).copy_(hidden_states.slice(0, i, i + 1));
                }
                hidden_states = concat;
                encoder_hidden_states = {};
            }

            // ---- single-stream block ----
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
            if (controlnet_single_block_samples.valid()) {
                const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];
                int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
                int block_index = (layer - transformer_blocks.size()) / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_single_block_samples

                // ControlNet residual applies to the image tokens only.
                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
            size_t local_layer_idx = layer - transformer_blocks.size();
            // Optional Python-side residual hook on every fourth single block,
            // applied to the image-token slice only.
            if (residual_callback && local_layer_idx % 4 == 0) {
                Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                Tensor cpu_input = callback_input.copy(Device::cpu());
                pybind11::gil_scoped_acquire gil;
                Tensor cpu_output = residual_callback(cpu_input);
                // Fix: copy back to the tensors' own device (was Device::cuda()).
                Tensor residual = cpu_output.copy(device);
                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice = kernels::add(slice, residual);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
        }
    };
    // Lazy-parameter hooks used by the offload helper to overlap H2D loads
    // with compute of the previous layer.
    auto load = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->loadLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->loadLazyParams();
        }
    };
    auto unload = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->releaseLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->releaseLazyParams();
        }
    };

    LayerOffloadHelper helper(this->offload, numLayers, compute, load, unload);
    helper.run();

    return hidden_states;
}

Hyunsung Lee's avatar
Hyunsung Lee committed
895
896
897
898
899
900
901
902
903
904
/// Runs exactly one transformer layer (used for external layer-by-layer
/// control, e.g. caching). For `layer < transformer_blocks.size()` this runs a
/// joint block on the two streams and applies the matching ControlNet
/// residual, if provided.
///
/// @param layer  global layer index across joint + single blocks
/// @return updated (hidden_states, encoder_hidden_states)
std::tuple<Tensor, Tensor> FluxModel::forward_layer(
        size_t layer,
        Tensor hidden_states,
        Tensor encoder_hidden_states,
        Tensor temb,
        Tensor rotary_emb_img,
        Tensor rotary_emb_context,
        Tensor controlnet_block_samples,
        Tensor controlnet_single_block_samples) {

    if (layer < transformer_blocks.size()){
        std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
            hidden_states,
            encoder_hidden_states,
            temb,
            rotary_emb_img,
            rotary_emb_context, 0.0f);
    }
    else {
        // NOTE(review): this branch indexes `transformer_blocks` with
        // `layer - transformer_blocks.size()`, which re-runs an early JOINT
        // block; it looks like it should dispatch to
        // `single_transformer_blocks` instead, but that would need a
        // `rotary_emb_single` argument this signature does not have.
        // Confirm intended behavior with callers before changing.
        std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer - transformer_blocks.size())->forward(
            hidden_states,
            encoder_hidden_states,
            temb,
            rotary_emb_img,
            rotary_emb_context, 0.0f);
    }

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
        const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

        // Spread the ControlNet residuals evenly over the joint blocks.
        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
        int block_index = layer / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_block_samples;

        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
    } else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
        const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

        int interval_control = ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
        int block_index = (layer - transformer_blocks.size()) / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_single_block_samples

        // Single-block ControlNet residual applies to the image tokens only
        // (tokens [txt_tokens, txt_tokens + img_tokens) of the fused stream).
        auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
        slice = kernels::add(slice, controlnet_single_block_samples[block_index]);
        hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
    }

    return { hidden_states, encoder_hidden_states };
}

950
951
952
953
954
955
956
957
/// Selects the attention implementation used by every block in the model.
void FluxModel::setAttentionImpl(AttentionImpl impl) {
    auto applyTo = [impl](auto &blocks) {
        for (auto &blk : blocks) {
            blk->attnImpl = impl;
        }
    };
    applyTo(this->transformer_blocks);
    applyTo(this->single_transformer_blocks);
}
K's avatar
K committed
958
959
960
/// Installs (or replaces) the Python-side residual callback invoked during
/// forward(). Pass an empty function to clear it.
void FluxModel::set_residual_callback(std::function<Tensor(const Tensor&)> cb) {
    // Swap into the member; the previous callback is destroyed when the
    // by-value parameter goes out of scope.
    residual_callback.swap(cb);
}