// FluxModel.cpp
#include "FluxModel.h"

#include "activation.h"
#include "flash_api.h"
#include "kernels/gemm_batched.h"
#include "kernels/misc_kernels.h"
#include "kernels/zgemm/zgemm.h"

#include <nvtx3/nvToolsExt.h>
#include <pybind11/functional.h>

#include <cassert>
#include <cmath>
#include <cstdint>
#include <iostream>

using spdlog::fmt_lib::format;

using namespace nunchaku;
// Runs a two-layer W4A4 MLP: fc1 with fused GELU + re-quantization, then fc2
// directly on the quantized activation (avoids a dequant/quant round-trip).
Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
    // fc1 produces a QuantizedActivation (GELU fused in); fc2 consumes it via forward_quant.
    Tensor ff_output = fc2.forward_quant(std::get<GEMM_W4A4::QuantizedActivation>(
        fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2)));
    return ff_output;
}

// Tensor forward_mlp(GEMM_W8A8 &fc1, GEMM_W8A8 &fc2, Tensor norm_hidden_states) {
//     Tensor ff_output = fc2.forward(fc1.forward(norm_hidden_states), GEMM_W8A8::FuseOptions::GELU);
//     return ff_output;
// }

// Thin wrapper: plain (unfused) forward pass through a single W4A4 linear layer.
Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
    return fc.forward(x);
    // return std::get<Tensor>(fc.forward(x));
}

// Tensor forward_fc(GEMM_W8A8 &fc, Tensor x) {
//     return fc.forward(x);
// }

// AdaLayerNormZeroSingle: adaptive LayerNorm producing 3 modulation tensors
// (shift/scale/gate), hence the 3*dim projection. Norm is elementwise-affine-free
// (affine comes from the conditioning embedding instead).
AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device)
    : dim(dim), linear(dim, 3 * dim, true, dtype, device), norm(dim, 1e-6, false, dtype, device) {
    registerChildren(linear, "linear")(norm, "norm");
}

// Applies adaptive LayerNorm: emb -> SiLU -> linear -> (shift, scale, gate),
// then norm(x) * (1 + scale) + shift. Returns the modulated activations plus
// the gate for the caller to apply after attention/MLP.
AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor emb) {
    debug("emb_input", emb);
    emb = linear.forward(Silu::forward(emb));
    debug("emb_linear", emb);

    // Split the 3*dim projection into three per-batch modulation tensors.
    auto &&[shift_msa, scale_msa, gate_msa] = kernels::split_mod<3>(emb);
    debug("scale_msa", scale_msa);
    debug("shift_msa", shift_msa);

    debug("x", x);
    Tensor norm_x = norm.forward(x);
    debug("norm_x", norm_x);

    // Batched variant of mul_add; the `true` flags select the batched broadcast path.
    // kernels::mul_add(norm_x, scale_msa, shift_msa);
    kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);

    return Output{norm_x, gate_msa};
}

// AdaLayerNormZero: adaptive LayerNorm for joint blocks. In pre_only mode only
// shift/scale are produced (2*dim); otherwise the full 6-way modulation
// (shift/scale/gate for both attention and MLP) is projected (6*dim).
AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device)
    : dim(dim), pre_only(pre_only), linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
      norm(dim, 1e-6, false, dtype, device) {
    registerChildren(linear, "linear")(norm, "norm");
}

// Adaptive LayerNorm forward. emb -> SiLU -> linear, then split into either
// 2 (pre_only) or 6 modulation tensors; x is normalized and modulated with
// the MSA shift/scale. The remaining tensors are returned for later use.
AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
    debug("x", x);

    debug("emb_input", emb);
    emb = linear.forward(Silu::forward(emb));
    debug("emb_linear", emb);

    if (pre_only) {
        auto &&[shift_msa, scale_msa] = kernels::split_mod<2>(emb);
        debug("shift_msa", shift_msa);

        Tensor norm_x = norm.forward(x);
        debug("norm_x", norm_x);

        // kernels::mul_add(norm_x, scale_msa, shift_msa);
        kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
        debug("norm_x_scaled", norm_x);

        return Output{norm_x};
    } else {
        auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(emb);
        debug("shift_msa", shift_msa);

        Tensor norm_x = norm.forward(x);
        debug("norm_x", norm_x);

        // kernels::mul_add(norm_x, scale_msa, shift_msa);
        kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
        debug("norm_x_scaled", norm_x);

        return Output{norm_x, gate_msa, shift_mlp, scale_mlp, gate_mlp};
    }
}

// Attention module. Precomputes headmask_type = [1, 2, ..., num_heads] on the
// CPU and uploads it to `device`; it is later passed to mha_fwd_block.
Attention::Attention(int num_heads, int dim_head, Device device)
    : num_heads(num_heads), dim_head(dim_head), force_fp16(false) {
    headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
    for (int i = 0; i < num_heads; i++) {
        headmask_type.data_ptr<int32_t>()[i] = i + 1; // 1-based head type ids
    }
    headmask_type = headmask_type.copy(device);
}

107
108
109
// Dense (non-sparse) attention over a packed QKV tensor.
// qkv: [batch, tokens, num_heads * dim_head * 3] -> returns
// [batch * tokens, num_heads, dim_head].
Tensor Attention::forward(Tensor qkv) {
    assert(qkv.ndims() == 3);

    const Device device  = qkv.device();
    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    // Split the packed projection into q/k/v views along the head axis.
    Tensor reshaped = qkv.view({batch_size, num_tokens, num_heads * 3, dim_head});
    Tensor q        = reshaped.slice(2, 0, num_heads);
    Tensor k        = reshaped.slice(2, num_heads, num_heads * 2);
    Tensor v        = reshaped.slice(2, num_heads * 2, num_heads * 3);

    // Standard softmax scale 1/sqrt(dim_head); no dropout, non-causal.
    Tensor raw_attn_output = mha_fwd(q, k, v, 0.0f, pow(q.shape[-1], (-0.5)), false, -1, -1, false).front();

    assert(raw_attn_output.shape[0] == batch_size);
    assert(raw_attn_output.shape[1] == num_tokens);
    assert(raw_attn_output.shape[2] == num_heads);
    assert(raw_attn_output.shape[3] == dim_head);

    return raw_attn_output.view({batch_size * num_tokens, num_heads, dim_head});
}

// Block-sparse attention. A pooled QKV (`pool_qkv`, POOL_SIZE tokens per pool
// entry) is used to score pool-block pairs; topk keeps the densest
// (1 - sparsityRatio) fraction as `blockmask` for mha_fwd_block.
// qkv: [batch, tokens, num_heads * dim_head * 3] -> [batch * tokens, num_heads, dim_head].
Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {

    // Optionally run the kernel in FP16 even for BF16 inputs (see setForceFP16).
    const bool cast_fp16 = this->force_fp16 && qkv.scalar_type() != Tensor::FP16;

    assert(qkv.ndims() == 3);

    const Device device  = qkv.device();
    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    constexpr int POOL_SIZE = 128;
    const int pool_tokens   = ceilDiv(num_tokens, POOL_SIZE);

    Tensor blockmask;

    if (pool_qkv.valid()) {
        assert(pool_qkv.shape[0] == batch_size);
        assert(pool_qkv.shape[1] == pool_tokens);
        assert(pool_qkv.shape[2] == num_heads * dim_head * 3);
    }

    Tensor pool_score = Tensor::allocate({batch_size, num_heads, pool_tokens, pool_tokens}, Tensor::FP32, device);

    if (pool_qkv.valid() && sparsityRatio > 0) {
        pool_qkv = pool_qkv.view({batch_size, pool_tokens, 3, num_heads, dim_head});
        pool_qkv = pool_qkv.transpose(1, 2).transpose(2, 3); // [batch_size, 3, num_heads, poolTokens, dim_head]
        for (int i = 0; i < batch_size; i++) {
            // Score pooled q against pooled k per batch item.
            Tensor pool_q = pool_qkv.slice(0, i, i + 1).slice(1, 0, 1);
            Tensor pool_k = pool_qkv.slice(0, i, i + 1).slice(1, 1, 2);
            Tensor pool_s = pool_score.slice(0, i, i + 1);
            gemm_batched_fp16(pool_q, pool_k, pool_s);
        }
    }

    // NOTE(review): when pool_qkv is invalid or sparsityRatio == 0, pool_score
    // is consumed here without being written — presumably topk then keeps all
    // blocks regardless of the scores; confirm against kernels::topk semantics.
    blockmask = kernels::topk(pool_score, pool_tokens * (1 - sparsityRatio));

    // Cache cu_seqlens on the CPU; invalidate if batch size or token count changed.
    if (cu_seqlens_cpu.valid()) {
        if (cu_seqlens_cpu.shape[0] != batch_size + 1) {
            cu_seqlens_cpu = Tensor{};
        } else {
            for (int i = 0; i <= batch_size; i++) {
                if (cu_seqlens_cpu.data_ptr<int32_t>()[i] != num_tokens * i) {
                    cu_seqlens_cpu = Tensor{};
                    break;
                }
            }
        }
    }
    if (!cu_seqlens_cpu.valid()) {
        // Cumulative sequence lengths [0, T, 2T, ...] for the varlen kernel.
        cu_seqlens_cpu                        = Tensor::allocate({batch_size + 1}, Tensor::INT32, Device::cpu());
        cu_seqlens_cpu.data_ptr<int32_t>()[0] = 0;
        for (int i = 1; i <= batch_size; i++) {
            cu_seqlens_cpu.data_ptr<int32_t>()[i] = cu_seqlens_cpu.data_ptr<int32_t>()[i - 1] + num_tokens;
        }
    }

    if (cast_fp16) {
        Tensor tmp = Tensor::empty(qkv.shape.dataExtent, Tensor::FP16, qkv.device());
        kernels::cast(qkv, tmp);
        qkv = tmp;
    }

    debug("qkv", qkv);

    Tensor cu_seqlens = cu_seqlens_cpu.copy(device);

    Tensor reshaped = qkv.view({batch_size * num_tokens, num_heads * 3, dim_head});
    Tensor q        = reshaped.slice(1, 0, num_heads);
    Tensor k        = reshaped.slice(1, num_heads, num_heads * 2);
    Tensor v        = reshaped.slice(1, num_heads * 2, num_heads * 3);

    spdlog::debug("q,k,v={}", q.shape.str());

    Tensor raw_attn_output = mha_fwd_block(q,
                                           k,
                                           v,
                                           cu_seqlens,
                                           cu_seqlens,
                                           POOL_SIZE,
                                           POOL_SIZE,
                                           headmask_type,
                                           {},
                                           blockmask,
                                           num_tokens,
                                           num_tokens,
                                           0.0f,
                                           pow(q.shape[-1], (-0.5)),
                                           false,
                                           false,
                                           false,
                                           -1,
                                           -1)
                                 .front();

    debug("raw_attn_output", raw_attn_output);

    if (cast_fp16) {
        // NOTE(review): the cast-back target is hard-coded to BF16 — assumes the
        // pre-cast activation dtype was BF16; confirm for FP32 pipelines.
        Tensor tmp = Tensor::empty(raw_attn_output.shape.dataExtent, Tensor::BF16, raw_attn_output.device());
        kernels::cast(raw_attn_output, tmp);
        raw_attn_output = tmp;
    }

    /**
    Tensor raw_attn_output = mha_varlen_fwd(q, k, v,
        cu_seqlens,
        cu_seqlens,
        concat.shape[1],
        concat.shape[1],
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false,
        true,
        -1, -1,
        false
    ).front();

    Tensor raw_attn_output = mha_fwd(q, k, v,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, -1, -1, false
    ).front();

    Tensor raw_attn_output = mha_varlen_fwd(
        q, k, v,
        cu_seqlens, cu_seqlens,
        num_tokens_img + num_tokens_txt, num_tokens_img + num_tokens_txt,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, false, -1, -1, false
    ).front();
    **/

    assert(raw_attn_output.shape[0] == batch_size * num_tokens);
    assert(raw_attn_output.shape[1] == num_heads);
    assert(raw_attn_output.shape[2] == dim_head);

    return raw_attn_output;
}

// Walks the module tree under `module` and sets force_fp16 on every Attention
// instance found (dynamic_cast filters non-Attention nodes).
void Attention::setForceFP16(Module *module, bool value) {
    spdlog::info("{} force fp16 attention", value ? "Enable" : "Disable");

    auto apply = [value](Module *node) {
        Attention *attn = dynamic_cast<Attention *>(node);
        if (attn != nullptr) {
            attn->force_fp16 = value;
        }
    };
    module->traverse(apply);
}

// Single-stream Flux transformer block: AdaLN-Zero norm, fused QKV projection
// with per-head RMS norms on q/k, attention, output projection, and a
// GELU MLP (hidden size = dim * mlp_ratio) combined residually.
FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim,
                                                       int num_attention_heads,
                                                       int attention_head_dim,
                                                       int mlp_ratio,
                                                       bool use_fp4,
                                                       Tensor::ScalarType dtype,
                                                       Device device)
    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
      mlp_hidden_dim(dim * mlp_ratio), norm(dim, dtype, device),
      mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
      mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
      norm_q(dim_head, 1e-6, false, dtype, device), norm_k(dim_head, 1e-6, false, dtype, device),
      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
      out_proj(dim, dim, true, use_fp4, dtype, device) {
    registerChildren(norm, "norm")(mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(qkv_proj, "qkv_proj")(norm_q, "norm_q")(
        norm_k, "norm_k")(attn, "attn")(out_proj, "out_proj");
}

// Single-block forward: AdaLN-Zero modulation, attention (FlashAttention2 or
// the Nunchaku FP16 kernel), parallel MLP on the same normed activations,
// then gated residual: out = gate * (attn + mlp) + residual.
Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {

    nvtxRangePushA("FluxSingleTransformerBlock");

    const int batch_size = hidden_states.shape[0];
    const int num_tokens = hidden_states.shape[1];

    auto &&[norm_hidden_states, gate] = this->norm.forward(hidden_states, temb);
    debug("norm_hidden_states", norm_hidden_states);
    debug("gate", gate);

    Tensor residual = hidden_states;

    Tensor attn_output;

    debug("rotary_emb", rotary_emb);

    if (attnImpl == AttentionImpl::FlashAttention2) {
        Tensor qkv = Tensor::allocate(
            {batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());

        // qkv_proj.forward(norm_hidden_states, qkv, {});
        // debug("qkv_raw", qkv);

        // Per-batch projection with fused q/k RMS norm and rotary embedding.
        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
                             qkv.slice(0, i, i + 1),
                             {},
                             norm_q.weight,
                             norm_k.weight,
                             rotary_emb);
        }
        debug("qkv", qkv);
        // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);

        // attn_output = attn.forward(qkv, {}, 0);
        attn_output = attn.forward(qkv);
        attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        // assert(batch_size == 1);

        // The FP16 kernel needs token counts padded to a multiple of 256.
        const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;

        Tensor q = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor k = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor v = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());

        // Projection writes q/k/v directly in the packed padded layout.
        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
                             {},
                             {},
                             norm_q.weight,
                             norm_k.weight,
                             rotary_emb,
                             q.slice(0, i, i + 1),
                             k.slice(0, i, i + 1),
                             v.slice(0, i, i + 1),
                             num_tokens);
        }

        debug("packed_q", q);
        debug("packed_k", k);
        debug("packed_v", v);

        Tensor o = Tensor::allocate({batch_size, num_tokens_pad, num_heads * dim_head},
                                    norm_hidden_states.scalar_type(),
                                    norm_hidden_states.device());

        kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));

        if (batch_size == 1 || num_tokens_pad == num_tokens) {
            // No per-batch depadding copy needed; a view/slice suffices.
            attn_output = o.slice(1, 0, num_tokens);
        } else {
            // Strided copy to drop the padding rows from each batch item.
            attn_output = Tensor::allocate({batch_size, num_tokens, num_heads * dim_head}, o.scalar_type(), o.device());
            checkCUDA(cudaMemcpy2DAsync(attn_output.data_ptr(),
                                        attn_output.stride(0) * attn_output.scalar_size(),
                                        o.data_ptr(),
                                        o.stride(0) * o.scalar_size(),
                                        attn_output.stride(0) * attn_output.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        getCurrentCUDAStream()));
        }
    } else {
        assert(false);
    }

    debug("raw_attn_output", attn_output);

    attn_output = forward_fc(out_proj, attn_output);
    debug("attn_output", attn_output);

    // MLP runs on the same modulated activations (parallel attention + MLP).
    Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
    debug("ff_output", ff_output);

    hidden_states = kernels::add(attn_output, ff_output);
    debug("attn_ff_output", hidden_states);

    // Gated residual: hidden_states = gate * hidden_states + residual.
    // kernels::mul_add(hidden_states, gate, residual);
    kernels::mul_add_batch(hidden_states, gate, true, 0.0, residual, true);

    nvtxRangePop();

    return hidden_states;
}

// Dual-stream (joint) transformer block: separate AdaLN norms, QKV projections,
// q/k RMS norms, output projections, and MLPs (hidden = dim * 4) for the image
// and text (context) streams, sharing a single attention over the concatenation.
// When context_pre_only is set, the context stream only produces pre-attention
// modulation (its norm1_context emits 2 instead of 6 modulation tensors).
JointTransformerBlock::JointTransformerBlock(int dim,
                                             int num_attention_heads,
                                             int attention_head_dim,
                                             bool context_pre_only,
                                             bool use_fp4,
                                             Tensor::ScalarType dtype,
                                             Device device)
    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
      context_pre_only(context_pre_only), norm1(dim, false, dtype, device),
      norm1_context(dim, context_pre_only, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
      qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device), norm_q(dim_head, 1e-6, false, dtype, device),
      norm_k(dim_head, 1e-6, false, dtype, device), norm_added_q(dim_head, 1e-6, false, dtype, device),
      norm_added_k(dim_head, 1e-6, false, dtype, device),
      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
      out_proj(dim, dim, true, use_fp4, dtype, device), out_proj_context(dim, dim, true, use_fp4, dtype, device),
      norm2(dim, 1e-6, false, dtype, device), norm2_context(dim, 1e-6, false, dtype, device),
      mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device), mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
      mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
      mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device) {
    registerChildren(norm1, "norm1")(norm1_context, "norm1_context")(qkv_proj, "qkv_proj")(qkv_proj_context,
                                                                                           "qkv_proj_context")(
        norm_q, "norm_q")(norm_k, "norm_k")(norm_added_q, "norm_added_q")(norm_added_k, "norm_added_k")(attn, "attn")(
        out_proj, "out_proj")(out_proj_context, "out_proj_context")(norm2, "norm2")(norm2_context, "norm2_context")(
        mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(mlp_context_fc1, "mlp_context_fc1")(mlp_context_fc2, "mlp_context_fc2");
}

// hidden_states: [Batch, Width * Height, dim]
// encoder_hidden_states: [Batch, Token, dim]
std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
                                                          Tensor encoder_hidden_states,
                                                          Tensor temb,
                                                          Tensor rotary_emb,
                                                          Tensor rotary_emb_context,
                                                          float sparsityRatio) {
Zhekai Zhang's avatar
Zhekai Zhang committed
439
440
441
442
443
444
445
446
    int batch_size = hidden_states.shape[0];
    assert(encoder_hidden_states.shape[0] == batch_size);

    nvtxRangePushA("JointTransformerBlock");

    nvtxRangePushA("AdaNorm");

    int num_tokens_img = hidden_states.shape[1];
447
    int num_tokens_txt = encoder_hidden_states.shape[1];
Hyunsung Lee's avatar
Hyunsung Lee committed
448

Zhekai Zhang's avatar
Zhekai Zhang committed
449
450
451
    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);

Muyang Li's avatar
Muyang Li committed
452
453
454
455
    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}",
                  hidden_states.shape.str(),
                  encoder_hidden_states.shape.str(),
                  temb.shape.str());
456
    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);
Zhekai Zhang's avatar
Zhekai Zhang committed
457

Muyang Li's avatar
Muyang Li committed
458
    auto norm1_output         = norm1.forward(hidden_states, temb);
Zhekai Zhang's avatar
Zhekai Zhang committed
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

#if 0
    norm1_output.x = hidden_states;
    norm1_context_output.x = encoder_hidden_states;
#endif

    debug("norm_hidden_states", norm1_output.x);
    debug("norm_encoder_hidden_states", norm1_context_output.x);

    constexpr int POOL_SIZE = Attention::POOL_SIZE;

    nvtxRangePop();

    auto stream = getCurrentCUDAStream();
Hyunsung Lee's avatar
Hyunsung Lee committed
474

475
476
    int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
    Tensor raw_attn_output;
Zhekai Zhang's avatar
Zhekai Zhang committed
477

478
479
480
    if (attnImpl == AttentionImpl::FlashAttention2) {
        num_tokens_img_pad = num_tokens_img;
        num_tokens_txt_pad = num_tokens_txt;
481

482
483
        Tensor concat;
        Tensor pool;
Hyunsung Lee's avatar
Hyunsung Lee committed
484

485
486
        {
            nvtxRangePushA("qkv_proj");
Hyunsung Lee's avatar
Hyunsung Lee committed
487

488
            const bool blockSparse = sparsityRatio > 0;
Hyunsung Lee's avatar
Hyunsung Lee committed
489

490
            const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
Muyang Li's avatar
Muyang Li committed
491
492
493
            concat               = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3},
                                      norm1_output.x.scalar_type(),
                                      norm1_output.x.device());
Hyunsung Lee's avatar
Hyunsung Lee committed
494

Muyang Li's avatar
Muyang Li committed
495
496
497
498
            pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3},
                                                  norm1_output.x.scalar_type(),
                                                  norm1_output.x.device())
                               : Tensor{};
Hyunsung Lee's avatar
Hyunsung Lee committed
499

500
501
502
            for (int i = 0; i < batch_size; i++) {
                // img first
                Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
Muyang Li's avatar
Muyang Li committed
503
504
                Tensor qkv_context =
                    concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
Hyunsung Lee's avatar
Hyunsung Lee committed
505

Muyang Li's avatar
Muyang Li committed
506
507
                Tensor pool_qkv =
                    pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
508
                Tensor pool_qkv_context = pool.valid()
Muyang Li's avatar
Muyang Li committed
509
510
511
512
513
                                              ? pool.slice(0, i, i + 1)
                                                    .slice(1,
                                                           num_tokens_img / POOL_SIZE,
                                                           num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
                                              : Tensor{};
Hyunsung Lee's avatar
Hyunsung Lee committed
514

515
516
                // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
                // debug("qkv_raw", qkv);
Hyunsung Lee's avatar
Hyunsung Lee committed
517

518
                debug("rotary_emb", rotary_emb);
Hyunsung Lee's avatar
Hyunsung Lee committed
519

Muyang Li's avatar
Muyang Li committed
520
521
                qkv_proj.forward(
                    norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
522
                debug("qkv", qkv);
Hyunsung Lee's avatar
Hyunsung Lee committed
523

524
525
                // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
                // debug("qkv_context_raw", qkv_context);
Hyunsung Lee's avatar
Hyunsung Lee committed
526

527
                debug("rotary_emb_context", rotary_emb_context);
Hyunsung Lee's avatar
Hyunsung Lee committed
528

Muyang Li's avatar
Muyang Li committed
529
530
531
532
533
534
                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         qkv_context,
                                         pool_qkv_context,
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context);
535
536
                debug("qkv_context", qkv_context);
            }
Hyunsung Lee's avatar
Hyunsung Lee committed
537

538
539
            nvtxRangePop();
        }
Hyunsung Lee's avatar
Hyunsung Lee committed
540

541
542
        spdlog::debug("concat={}", concat.shape.str());
        debug("concat", concat);
Hyunsung Lee's avatar
Hyunsung Lee committed
543

544
        assert(concat.shape[2] == num_heads * dim_head * 3);
Hyunsung Lee's avatar
Hyunsung Lee committed
545

546
        nvtxRangePushA("Attention");
Hyunsung Lee's avatar
Hyunsung Lee committed
547

548
549
550
551
552
        if (pool.valid()) {
            raw_attn_output = attn.forward(concat, pool, sparsityRatio);
        } else {
            raw_attn_output = attn.forward(concat);
        }
Hyunsung Lee's avatar
Hyunsung Lee committed
553

554
        nvtxRangePop();
Hyunsung Lee's avatar
Hyunsung Lee committed
555

556
        spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
Hyunsung Lee's avatar
Hyunsung Lee committed
557

558
        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});
Hyunsung Lee's avatar
Hyunsung Lee committed
559

560
561
562
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
        num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;
Zhekai Zhang's avatar
Zhekai Zhang committed
563

564
        Tensor concat_q, concat_k, concat_v;
Zhekai Zhang's avatar
Zhekai Zhang committed
565

566
567
        {
            nvtxRangePushA("qkv_proj");
Hyunsung Lee's avatar
Hyunsung Lee committed
568

Muyang Li's avatar
Muyang Li committed
569
570
571
            concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head},
                                        Tensor::FP16,
                                        norm1_output.x.device());
572
573
            concat_k = Tensor::empty_like(concat_q);
            concat_v = Tensor::empty_like(concat_q);
Hyunsung Lee's avatar
Hyunsung Lee committed
574

575
576
            for (int i = 0; i < batch_size; i++) {
                // img first
Muyang Li's avatar
Muyang Li committed
577
                auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
578
                auto sliceTxt = [&](Tensor x) {
Muyang Li's avatar
Muyang Li committed
579
                    return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
580
                };
Hyunsung Lee's avatar
Hyunsung Lee committed
581

Muyang Li's avatar
Muyang Li committed
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1),
                                 {},
                                 {},
                                 norm_q.weight,
                                 norm_k.weight,
                                 rotary_emb,
                                 sliceImg(concat_q),
                                 sliceImg(concat_k),
                                 sliceImg(concat_v),
                                 num_tokens_img);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         {},
                                         {},
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context,
                                         sliceTxt(concat_q),
                                         sliceTxt(concat_k),
                                         sliceTxt(concat_v),
                                         num_tokens_txt);
603
            }
Zhekai Zhang's avatar
Zhekai Zhang committed
604

605
606
607
            debug("concat_q", concat_q);
            debug("concat_k", concat_k);
            debug("concat_v", concat_v);
Hyunsung Lee's avatar
Hyunsung Lee committed
608

609
            nvtxRangePop();
Zhekai Zhang's avatar
Zhekai Zhang committed
610
611
        }

Muyang Li's avatar
Muyang Li committed
612
613
614
        raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head},
                                           norm1_output.x.scalar_type(),
                                           norm1_output.x.device());
Zhekai Zhang's avatar
Zhekai Zhang committed
615

616
        nvtxRangePushA("Attention");
Zhekai Zhang's avatar
Zhekai Zhang committed
617

618
        kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));
Zhekai Zhang's avatar
Zhekai Zhang committed
619

620
        nvtxRangePop();
Zhekai Zhang's avatar
Zhekai Zhang committed
621

Muyang Li's avatar
Muyang Li committed
622
623
        raw_attn_output =
            raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
624
625
626
    } else {
        assert(false);
    }
Zhekai Zhang's avatar
Zhekai Zhang committed
627
628
629
630
631
632
633
634

    debug("raw_attn_output", raw_attn_output);

    {
        nvtxRangePushA("o_proj");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;

635
        // raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]
Zhekai Zhang's avatar
Zhekai Zhang committed
636
637
638

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
Muyang Li's avatar
Muyang Li committed
639
640
            raw_attn_output_split =
                raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
Zhekai Zhang's avatar
Zhekai Zhang committed
641
        } else {
Muyang Li's avatar
Muyang Li committed
642
643
644
645
646
647
648
649
650
651
652
653
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
Zhekai Zhang's avatar
Zhekai Zhang committed
654
        }
muyangli's avatar
muyangli committed
655

Zhekai Zhang's avatar
Zhekai Zhang committed
656
657
658
        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("img.raw_attn_output_split", raw_attn_output_split);

Muyang Li's avatar
Muyang Li committed
659
660
        Tensor attn_output =
            forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
Zhekai Zhang's avatar
Zhekai Zhang committed
661
662
663
        debug("img.attn_output", attn_output);

#if 1
664
665
        // kernels::mul_add(attn_output, gate_msa, hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
666
667
668
669
670
671
672
673
674
675
        hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", hidden_states.shape.str());

        Tensor norm_hidden_states = norm2.forward(hidden_states);
        debug("scale_mlp", scale_mlp);
        debug("shift_mlp", shift_mlp);
676
677
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
678
679
680
681
682
683
684
685
686
687
688
689

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        Tensor norm_hidden_states = hidden_states;
#endif

        // Tensor ff_output = mlp_fc2.forward(GELU::forward(mlp_fc1.forward(norm_hidden_states)));
        debug("img.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
        debug("img.ff_output", ff_output);

        debug("gate_mlp", gate_mlp);
690
691
        // kernels::mul_add(ff_output, gate_mlp, hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
692
693
694
695
696
697
698
699
        hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", hidden_states.shape.str());
    }

    if (context_pre_only) {
Muyang Li's avatar
Muyang Li committed
700
        return {hidden_states, encoder_hidden_states};
Zhekai Zhang's avatar
Zhekai Zhang committed
701
702
703
704
705
706
707
708
709
    }

    {
        nvtxRangePushA("o_proj_context");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_context_output;

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
Muyang Li's avatar
Muyang Li committed
710
711
            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt)
                                        .reshape({batch_size, num_tokens_txt, num_heads * dim_head});
Zhekai Zhang's avatar
Zhekai Zhang committed
712
        } else {
Muyang Li's avatar
Muyang Li committed
713
714
715
716
717
718
719
720
721
722
723
724
725
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head *
                                                                               raw_attn_output_split.scalar_size(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
Zhekai Zhang's avatar
Zhekai Zhang committed
726
        }
muyangli's avatar
muyangli committed
727

Zhekai Zhang's avatar
Zhekai Zhang committed
728
729
730
        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("context.raw_attn_output_split", raw_attn_output_split);

Muyang Li's avatar
Muyang Li committed
731
732
733
        Tensor attn_output =
            forward_fc(out_proj_context,
                       raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
Zhekai Zhang's avatar
Zhekai Zhang committed
734
735
736
        debug("context.attn_output", attn_output);

#if 1
737
738
        // kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
739
740
741
742
743
744
745
746
747
748
        encoder_hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", encoder_hidden_states.shape.str());

        Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
        debug("c_scale_mlp", scale_mlp);
        debug("c_shift_mlp", shift_mlp);
749
750
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
751
752
753
754
755

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        auto norm_hidden_states = encoder_hidden_states;
#endif
muyangli's avatar
muyangli committed
756

Zhekai Zhang's avatar
Zhekai Zhang committed
757
        // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
Muyang Li's avatar
Muyang Li committed
758
759
        // Tensor ff_output =
        // mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
Zhekai Zhang's avatar
Zhekai Zhang committed
760
761
762
763
764
        debug("context.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
        debug("context.ff_output", ff_output);

        debug("c_gate_mlp", gate_mlp);
765
766
        // kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
767
768
769
770
771
772
773
774
775
        encoder_hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", encoder_hidden_states.shape.str());
    }

    nvtxRangePop();

Muyang Li's avatar
Muyang Li committed
776
    return {hidden_states, encoder_hidden_states};
Zhekai Zhang's avatar
Zhekai Zhang committed
777
778
}

Muyang Li's avatar
Muyang Li committed
779
780
FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
    : dtype(dtype), offload(offload) {
    // 19 joint (dual-stream img/txt) transformer blocks.
    for (int idx = 0; idx < 19; ++idx) {
        auto block = std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device);
        registerChildren(*block, format("transformer_blocks.{}", idx));
        if (offload && idx > 0) {
            // Keep the first block resident; all later blocks are lazily
            // loaded/released around their execution.
            block->setLazyLoad(true);
            block->releaseLazyParams();
        }
        transformer_blocks.push_back(std::move(block));
    }
    // 38 single-stream transformer blocks.
    for (int idx = 0; idx < 38; ++idx) {
        auto block = std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, device);
        registerChildren(*block, format("single_transformer_blocks.{}", idx));
        if (offload) {
            block->setLazyLoad(true);
            block->releaseLazyParams();
        }
        single_transformer_blocks.push_back(std::move(block));
    }
}

Muyang Li's avatar
Muyang Li committed
801
802
803
804
805
806
807
808
809
810
Tensor FluxModel::forward(Tensor hidden_states,
                          Tensor encoder_hidden_states,
                          Tensor temb,
                          Tensor rotary_emb_img,
                          Tensor rotary_emb_context,
                          Tensor rotary_emb_single,
                          Tensor controlnet_block_samples,
                          Tensor controlnet_single_block_samples,
                          bool skip_first_layer) {
    const int batch_size           = hidden_states.shape[0];
Zhekai Zhang's avatar
Zhekai Zhang committed
811
    const Tensor::ScalarType dtype = hidden_states.dtype();
Muyang Li's avatar
Muyang Li committed
812
    const Device device            = hidden_states.device();
Zhekai Zhang's avatar
Zhekai Zhang committed
813
814
815
816

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

muyangli's avatar
muyangli committed
817
    const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();
Zhekai Zhang's avatar
Zhekai Zhang committed
818

muyangli's avatar
muyangli committed
819
    Tensor concat;
Zhekai Zhang's avatar
Zhekai Zhang committed
820

muyangli's avatar
muyangli committed
821
    auto compute = [&](int layer) {
Muyang Li's avatar
Muyang Li committed
822
823
        if (skip_first_layer && size_t(layer) == 0)
            return;
muyangli's avatar
muyangli committed
824
825
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
Muyang Li's avatar
Muyang Li committed
826
827
            std::tie(hidden_states, encoder_hidden_states) =
                block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
Hyunsung Lee's avatar
Hyunsung Lee committed
828
            if (controlnet_block_samples.valid()) {
829
830
                const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

Muyang Li's avatar
Muyang Li committed
831
832
                int interval_control =
                    ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
Hyunsung Lee's avatar
Hyunsung Lee committed
833
834
835
836
837
838
                int block_index = layer / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_block_samples;

                hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
            }
K's avatar
K committed
839
840
841
842
            if (residual_callback && layer % 2 == 0) {
                Tensor cpu_input = hidden_states.copy(Device::cpu());
                pybind11::gil_scoped_acquire gil;
                Tensor cpu_output = residual_callback(cpu_input);
Muyang Li's avatar
Muyang Li committed
843
844
                Tensor residual   = cpu_output.copy(Device::cuda());
                hidden_states     = kernels::add(hidden_states, residual);
K's avatar
K committed
845
            }
muyangli's avatar
muyangli committed
846
847
848
849
850
        } else {
            if (size_t(layer) == transformer_blocks.size()) {
                // txt first, same as diffusers
                concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
                for (int i = 0; i < batch_size; i++) {
851
                    concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states.slice(0, i, i + 1));
Muyang Li's avatar
Muyang Li committed
852
853
854
                    concat.slice(0, i, i + 1)
                        .slice(1, txt_tokens, txt_tokens + img_tokens)
                        .copy_(hidden_states.slice(0, i, i + 1));
muyangli's avatar
muyangli committed
855
                }
Muyang Li's avatar
Muyang Li committed
856
                hidden_states         = concat;
muyangli's avatar
muyangli committed
857
858
859
                encoder_hidden_states = {};
            }

Muyang Li's avatar
Muyang Li committed
860
            auto &block   = single_transformer_blocks.at(layer - transformer_blocks.size());
muyangli's avatar
muyangli committed
861
            hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
Hyunsung Lee's avatar
Hyunsung Lee committed
862
            if (controlnet_single_block_samples.valid()) {
863
864
                const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

Muyang Li's avatar
Muyang Li committed
865
866
                int interval_control =
                    ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
Hyunsung Lee's avatar
Hyunsung Lee committed
867
868
869
870
871
                int block_index = (layer - transformer_blocks.size()) / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_single_block_samples

                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
Muyang Li's avatar
Muyang Li committed
872
                slice      = kernels::add(slice, controlnet_single_block_samples[block_index]);
Hyunsung Lee's avatar
Hyunsung Lee committed
873
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
Muyang Li's avatar
Muyang Li committed
874
            }
K's avatar
K committed
875
876
877
            size_t local_layer_idx = layer - transformer_blocks.size();
            if (residual_callback && local_layer_idx % 4 == 0) {
                Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
Muyang Li's avatar
Muyang Li committed
878
                Tensor cpu_input      = callback_input.copy(Device::cpu());
K's avatar
K committed
879
880
                pybind11::gil_scoped_acquire gil;
                Tensor cpu_output = residual_callback(cpu_input);
Muyang Li's avatar
Muyang Li committed
881
882
883
                Tensor residual   = cpu_output.copy(Device::cuda());
                auto slice        = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice             = kernels::add(slice, residual);
K's avatar
K committed
884
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
Hyunsung Lee's avatar
Hyunsung Lee committed
885
            }
muyangli's avatar
muyangli committed
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
        }
    };
    auto load = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->loadLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->loadLazyParams();
        }
    };
    auto unload = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->releaseLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->releaseLazyParams();
        }
    };

    LayerOffloadHelper helper(this->offload, numLayers, compute, load, unload);
    helper.run();
Zhekai Zhang's avatar
Zhekai Zhang committed
909
910

    return hidden_states;
911
912
}

Muyang Li's avatar
Muyang Li committed
913
914
915
916
917
918
919
920
921
std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
                                                    Tensor hidden_states,
                                                    Tensor encoder_hidden_states,
                                                    Tensor temb,
                                                    Tensor rotary_emb_img,
                                                    Tensor rotary_emb_context,
                                                    Tensor controlnet_block_samples,
                                                    Tensor controlnet_single_block_samples) {

922
923
924
925
926
927
928
929
    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->loadLazyParams();
        } else {
            transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
        }
    }

Muyang Li's avatar
Muyang Li committed
930
    if (layer < transformer_blocks.size()) {
931
        std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
Muyang Li's avatar
Muyang Li committed
932
933
934
935
936
            hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    } else {
        std::tie(hidden_states, encoder_hidden_states) =
            transformer_blocks.at(layer - transformer_blocks.size())
                ->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
937
    }
Hyunsung Lee's avatar
Hyunsung Lee committed
938
939
940
941
942

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
943
944
        const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

Hyunsung Lee's avatar
Hyunsung Lee committed
945
        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
Muyang Li's avatar
Muyang Li committed
946
        int block_index      = layer / interval_control;
Hyunsung Lee's avatar
Hyunsung Lee committed
947
948
949
950
951
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_block_samples;

        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
    } else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
952
953
        const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

Muyang Li's avatar
Muyang Li committed
954
955
        int interval_control =
            ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
Hyunsung Lee's avatar
Hyunsung Lee committed
956
957
958
959
960
        int block_index = (layer - transformer_blocks.size()) / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_single_block_samples

        auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
Muyang Li's avatar
Muyang Li committed
961
        slice      = kernels::add(slice, controlnet_single_block_samples[block_index]);
Hyunsung Lee's avatar
Hyunsung Lee committed
962
963
964
        hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
    }

965
966
967
968
969
970
971
972
    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->releaseLazyParams();
        } else {
            transformer_blocks.at(layer - transformer_blocks.size())->releaseLazyParams();
        }
    }

Muyang Li's avatar
Muyang Li committed
973
    return {hidden_states, encoder_hidden_states};
Hyunsung Lee's avatar
Hyunsung Lee committed
974
975
}

976
977
978
979
980
981
982
983
void FluxModel::setAttentionImpl(AttentionImpl impl) {
    // Propagate the chosen attention backend to every block in the stack.
    auto apply = [impl](auto &blocks) {
        for (auto &block : blocks) {
            block->attnImpl = impl;
        }
    };
    apply(this->transformer_blocks);
    apply(this->single_transformer_blocks);
}
Muyang Li's avatar
Muyang Li committed
984
// Installs (or, with an empty function, clears) the Python-side residual
// callback that forward() invokes on intermediate hidden states.
void FluxModel::set_residual_callback(std::function<Tensor(const Tensor &)> cb) {
    this->residual_callback = std::move(cb);
}