"vscode:/vscode.git/clone" did not exist on "2e52e96317554130cb76fd31b6735d7ed225e024"
#include "FluxModel.h"
#include "kernels/misc_kernels.h"
#include "kernels/gemm_batched.h"
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "activation.h"
// #include <nvtx3/nvToolsExt.h>
#include <roctx.h>

#include <pybind11/functional.h>

#include <iostream>

using spdlog::fmt_lib::format;
using namespace nunchaku;

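// Feed-forward helper: fc1 is evaluated with FuseOptions::GELU_QUANT so its
// GELU output is emitted already quantized for fc2, and fc2 consumes that
// QuantizedActivation directly via forward_quant.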
Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
Muyang Li's avatar
Muyang Li committed
18
19
    Tensor ff_output = fc2.forward_quant(std::get<GEMM_W4A4::QuantizedActivation>(
        fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2)));
    return ff_output;
}

// Tensor forward_mlp(GEMM_W8A8 &fc1, GEMM_W8A8 &fc2, Tensor norm_hidden_states) {
//     Tensor ff_output = fc2.forward(fc1.forward(norm_hidden_states), GEMM_W8A8::FuseOptions::GELU);
//     return ff_output;
// }

Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
    return fc.forward(x);
    // return std::get<Tensor>(fc.forward(x));
}

// Tensor forward_fc(GEMM_W8A8 &fc, Tensor x) {
//     return fc.forward(x);
// }

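// AdaLayerNormZeroSingle: adaLN-Zero modulation used by the single-stream
// blocks. A SiLU + linear projection of the conditioning embedding yields
// (shift, scale, gate); shift and scale modulate the layer-normed input and
// the gate is returned for the residual connection.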
AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device)
    : dim(dim), linear(dim, 3 * dim, true, dtype, device), norm(dim, 1e-6, false, dtype, device) {
    registerChildren(linear, "linear")(norm, "norm");
}

AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor emb) {
    debug("emb_input", emb);
    emb = linear.forward(Silu::forward(emb));
    debug("emb_linear", emb);
    auto &&[shift_msa, scale_msa, gate_msa] = kernels::split_mod<3>(emb);
    debug("scale_msa", scale_msa);
    debug("shift_msa", shift_msa);

    debug("x", x);
    Tensor norm_x = norm.forward(x);
    debug("norm_x", norm_x);

    // kernels::mul_add(norm_x, scale_msa, shift_msa);
    kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
    return Output{norm_x, gate_msa};
}

AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device)
    : dim(dim), pre_only(pre_only), linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
      norm(dim, 1e-6, false, dtype, device) {
    registerChildren(linear, "linear")(norm, "norm");
}

AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
    debug("x", x);

    debug("emb_input", emb);
    emb = linear.forward(Silu::forward(emb));
    debug("emb_linear", emb);

    if (pre_only) {
        auto &&[shift_msa, scale_msa] = kernels::split_mod<2>(emb);
        debug("shift_msa", shift_msa);

        Tensor norm_x = norm.forward(x);
        debug("norm_x", norm_x);

        // kernels::mul_add(norm_x, scale_msa, shift_msa);
        kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
        debug("norm_x_scaled", norm_x);

        return Output{norm_x};
    } else {
        auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(emb);
        debug("shift_msa", shift_msa);

        Tensor norm_x = norm.forward(x);
        debug("norm_x", norm_x);

        // kernels::mul_add(norm_x, scale_msa, shift_msa);
        kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
        debug("norm_x_scaled", norm_x);

        return Output{norm_x, gate_msa, shift_mlp, scale_mlp, gate_mlp};
    }
}

Attention::Attention(int num_heads, int dim_head, Device device)
    : num_heads(num_heads), dim_head(dim_head), force_fp16(false) {
    headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
    for (int i = 0; i < num_heads; i++) {
        headmask_type.data_ptr<int32_t>()[i] = i + 1;
    }
    headmask_type = headmask_type.copy(device);
}

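// Dense attention path: qkv is packed as [batch, tokens, num_heads * dim_head * 3].
// It is viewed with the head axis split so q/k/v can be sliced out, run through
// the FlashAttention forward (mha_fwd), and the result is returned flattened to
// [batch * tokens, num_heads, dim_head].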
Tensor Attention::forward(Tensor qkv) {
    assert(qkv.ndims() == 3);

    const Device device  = qkv.device();
    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    Tensor reshaped = qkv.view({batch_size, num_tokens, num_heads * 3, dim_head});
    Tensor q        = reshaped.slice(2, 0, num_heads);
    Tensor k        = reshaped.slice(2, num_heads, num_heads * 2);
    Tensor v        = reshaped.slice(2, num_heads * 2, num_heads * 3);

    Tensor raw_attn_output = mha_fwd(q, k, v, 0.0f, pow(q.shape[-1], (-0.5)), false, -1, -1, false).front();

    assert(raw_attn_output.shape[0] == batch_size);
    assert(raw_attn_output.shape[1] == num_tokens);
    assert(raw_attn_output.shape[2] == num_heads);
    assert(raw_attn_output.shape[3] == dim_head);

    return raw_attn_output.view({batch_size * num_tokens, num_heads, dim_head});
}

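// Block-sparse attention path: when pool_qkv is provided and sparsityRatio > 0,
// per-block scores are estimated from the pooled q/k with a batched FP16 GEMM,
// kernels::topk keeps the highest-scoring (1 - sparsityRatio) fraction of pool
// blocks, and mha_fwd_block attends only within the kept blocks. qkv may first
// be cast to FP16 when force_fp16 is enabled.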
Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
    const bool cast_fp16 = this->force_fp16 && qkv.scalar_type() != Tensor::FP16;

    assert(qkv.ndims() == 3);

    const Device device  = qkv.device();
    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    constexpr int POOL_SIZE = 128;
    const int pool_tokens   = ceilDiv(num_tokens, POOL_SIZE);

    Tensor blockmask;

    if (pool_qkv.valid()) {
        assert(pool_qkv.shape[0] == batch_size);
        assert(pool_qkv.shape[1] == pool_tokens);
        assert(pool_qkv.shape[2] == num_heads * dim_head * 3);
    }

    Tensor pool_score = Tensor::allocate({batch_size, num_heads, pool_tokens, pool_tokens}, Tensor::FP32, device);

    if (pool_qkv.valid() && sparsityRatio > 0) {
        pool_qkv = pool_qkv.view({batch_size, pool_tokens, 3, num_heads, dim_head});
        pool_qkv = pool_qkv.transpose(1, 2).transpose(2, 3); // [batch_size, 3, num_heads, poolTokens, dim_head]
        for (int i = 0; i < batch_size; i++) {
            Tensor pool_q = pool_qkv.slice(0, i, i + 1).slice(1, 0, 1);
            Tensor pool_k = pool_qkv.slice(0, i, i + 1).slice(1, 1, 2);
            Tensor pool_s = pool_score.slice(0, i, i + 1);
            gemm_batched_fp16(pool_q, pool_k, pool_s);
        }
    }

    blockmask = kernels::topk(pool_score, pool_tokens * (1 - sparsityRatio));

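    // Reuse the cached cu_seqlens if it still matches this batch
    // (batch_size + 1 entries, entry i == i * num_tokens); otherwise rebuild
    // the prefix sums on the CPU before copying them to the device.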
    if (cu_seqlens_cpu.valid()) {
        if (cu_seqlens_cpu.shape[0] != batch_size + 1) {
            cu_seqlens_cpu = Tensor{};
        } else {
            for (int i = 0; i <= batch_size; i++) {
                if (cu_seqlens_cpu.data_ptr<int32_t>()[i] != num_tokens * i) {
                    cu_seqlens_cpu = Tensor{};
                    break;
                }
            }
        }
    }
    if (!cu_seqlens_cpu.valid()) {
        cu_seqlens_cpu                        = Tensor::allocate({batch_size + 1}, Tensor::INT32, Device::cpu());
        cu_seqlens_cpu.data_ptr<int32_t>()[0] = 0;
        for (int i = 1; i <= batch_size; i++) {
            cu_seqlens_cpu.data_ptr<int32_t>()[i] = cu_seqlens_cpu.data_ptr<int32_t>()[i - 1] + num_tokens;
        }
    }

    if (cast_fp16) {
        Tensor tmp = Tensor::empty(qkv.shape.dataExtent, Tensor::FP16, qkv.device());
        kernels::cast(qkv, tmp);
        qkv = tmp;
    }

    debug("qkv", qkv);

    Tensor cu_seqlens = cu_seqlens_cpu.copy(device);

    Tensor reshaped = qkv.view({batch_size * num_tokens, num_heads * 3, dim_head});
    Tensor q        = reshaped.slice(1, 0, num_heads);
    Tensor k        = reshaped.slice(1, num_heads, num_heads * 2);
    Tensor v        = reshaped.slice(1, num_heads * 2, num_heads * 3);

    spdlog::debug("q,k,v={}", q.shape.str());

    Tensor raw_attn_output = mha_fwd_block(q,
                                           k,
                                           v,
                                           cu_seqlens,
                                           cu_seqlens,
                                           POOL_SIZE,
                                           POOL_SIZE,
                                           headmask_type,
                                           {},
                                           blockmask,
                                           num_tokens,
                                           num_tokens,
                                           0.0f,
                                           pow(q.shape[-1], (-0.5)),
                                           false,
                                           false,
                                           false,
                                           -1,
                                           -1)
                                 .front();

    debug("raw_attn_output", raw_attn_output);

    if (cast_fp16) {
        Tensor tmp = Tensor::empty(raw_attn_output.shape.dataExtent, Tensor::BF16, raw_attn_output.device());
        kernels::cast(raw_attn_output, tmp);
        raw_attn_output = tmp;
    }

    /**
    Tensor raw_attn_output = mha_varlen_fwd(q, k, v,
        cu_seqlens,
        cu_seqlens,
        concat.shape[1],
        concat.shape[1],
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false,
        true,
        -1, -1,
        false
    ).front();

    Tensor raw_attn_output = mha_fwd(q, k, v,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, -1, -1, false
    ).front();

    Tensor raw_attn_output = mha_varlen_fwd(
        q, k, v,
        cu_seqlens, cu_seqlens,
        num_tokens_img + num_tokens_txt, num_tokens_img + num_tokens_txt,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, false, -1, -1, false
    ).front();
    **/

    assert(raw_attn_output.shape[0] == batch_size * num_tokens);
    assert(raw_attn_output.shape[1] == num_heads);
    assert(raw_attn_output.shape[2] == dim_head);

    return raw_attn_output;
}

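// Traverses the module tree and toggles force_fp16 on every Attention submodule.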
void Attention::setForceFP16(Module *module, bool value) {
    spdlog::info("{} force fp16 attention", value ? "Enable" : "Disable");

    module->traverse([&](Module *m) {
        if (Attention *attn = dynamic_cast<Attention *>(m)) {
            attn->force_fp16 = value;
        }
    });
}

FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim,
                                                       int num_attention_heads,
                                                       int attention_head_dim,
                                                       int mlp_ratio,
                                                       bool use_fp4,
                                                       Tensor::ScalarType dtype,
                                                       Device device)
    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
      mlp_hidden_dim(dim * mlp_ratio), norm(dim, dtype, device),
      mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
      mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
      norm_q(dim_head, 1e-6, false, dtype, device), norm_k(dim_head, 1e-6, false, dtype, device),
      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
      out_proj(dim, dim, true, use_fp4, dtype, device) {
    registerChildren(norm, "norm")(mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(qkv_proj, "qkv_proj")(norm_q, "norm_q")(
        norm_k, "norm_k")(attn, "attn")(out_proj, "out_proj");
}

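// Single-stream transformer block: adaLN-Zero modulation, a fused QKV projection
// that also applies the Q/K RMSNorm weights and rotary embedding, attention via
// either FlashAttention-2 or the Nunchaku FP16 kernel (which pads to a multiple
// of 256 tokens), then the attention projection and MLP outputs are summed and
// added back onto the residual through the adaLN gate.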
Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {

    nvtxRangePushA("FluxSingleTransformerBlock");

    const int batch_size = hidden_states.shape[0];
    const int num_tokens = hidden_states.shape[1];

    auto &&[norm_hidden_states, gate] = this->norm.forward(hidden_states, temb);
    debug("norm_hidden_states", norm_hidden_states);
    debug("gate", gate);

    Tensor residual = hidden_states;

    Tensor attn_output;

    debug("rotary_emb", rotary_emb);

    if (attnImpl == AttentionImpl::FlashAttention2) {
        Tensor qkv = Tensor::allocate(
            {batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
        // qkv_proj.forward(norm_hidden_states, qkv, {});
        // debug("qkv_raw", qkv);

        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
                             qkv.slice(0, i, i + 1),
                             {},
                             norm_q.weight,
                             norm_k.weight,
                             rotary_emb);
        }
        debug("qkv", qkv);
        // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);

        // attn_output = attn.forward(qkv, {}, 0);
        attn_output = attn.forward(qkv);
        attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        // assert(batch_size == 1);

        const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;

        Tensor q = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor k = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor v = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());

        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
                             {},
                             {},
                             norm_q.weight,
                             norm_k.weight,
                             rotary_emb,
                             q.slice(0, i, i + 1),
                             k.slice(0, i, i + 1),
                             v.slice(0, i, i + 1),
                             num_tokens);
        }

        debug("packed_q", q);
        debug("packed_k", k);
        debug("packed_v", v);

        Tensor o = Tensor::allocate({batch_size, num_tokens_pad, num_heads * dim_head},
                                    norm_hidden_states.scalar_type(),
                                    norm_hidden_states.device());

        kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));

        if (batch_size == 1 || num_tokens_pad == num_tokens) {
            attn_output = o.slice(1, 0, num_tokens);
        } else {
            attn_output = Tensor::allocate({batch_size, num_tokens, num_heads * dim_head}, o.scalar_type(), o.device());
            checkCUDA(cudaMemcpy2DAsync(attn_output.data_ptr(),
                                        attn_output.stride(0) * attn_output.scalar_size(),
                                        o.data_ptr(),
                                        o.stride(0) * o.scalar_size(),
                                        attn_output.stride(0) * attn_output.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        getCurrentCUDAStream()));
        }
    } else {
        assert(false);
    }

    debug("raw_attn_output", attn_output);

    attn_output = forward_fc(out_proj, attn_output);
    debug("attn_output", attn_output);

    Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
    debug("ff_output", ff_output);

    hidden_states = kernels::add(attn_output, ff_output);
    debug("attn_ff_output", hidden_states);

    // kernels::mul_add(hidden_states, gate, residual);
    kernels::mul_add_batch(hidden_states, gate, true, 0.0, residual, true);

    nvtxRangePop();

    return hidden_states;
}

JointTransformerBlock::JointTransformerBlock(int dim,
                                             int num_attention_heads,
                                             int attention_head_dim,
                                             bool context_pre_only,
                                             bool use_fp4,
                                             Tensor::ScalarType dtype,
                                             Device device)
    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
      context_pre_only(context_pre_only), norm1(dim, false, dtype, device),
      norm1_context(dim, context_pre_only, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
      qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device), norm_q(dim_head, 1e-6, false, dtype, device),
      norm_k(dim_head, 1e-6, false, dtype, device), norm_added_q(dim_head, 1e-6, false, dtype, device),
      norm_added_k(dim_head, 1e-6, false, dtype, device),
      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
      out_proj(dim, dim, true, use_fp4, dtype, device), out_proj_context(dim, dim, true, use_fp4, dtype, device),
      norm2(dim, 1e-6, false, dtype, device), norm2_context(dim, 1e-6, false, dtype, device),
      mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device), mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
      mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
      mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device) {
    registerChildren(norm1, "norm1")(norm1_context, "norm1_context")(qkv_proj, "qkv_proj")(qkv_proj_context,
                                                                                           "qkv_proj_context")(
        norm_q, "norm_q")(norm_k, "norm_k")(norm_added_q, "norm_added_q")(norm_added_k, "norm_added_k")(attn, "attn")(
        out_proj, "out_proj")(out_proj_context, "out_proj_context")(norm2, "norm2")(norm2_context, "norm2_context")(
        mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(mlp_context_fc1, "mlp_context_fc1")(mlp_context_fc2, "mlp_context_fc2");
}

// hidden_states: [Batch, Width * Height, dim]
// encoder_hidden_states: [Batch, Token, dim]
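// Dual-stream block: the image and text streams are modulated and QKV-projected
// separately (image tokens first in the packed layout), attended jointly, then
// split apart again so each stream gets its own output projection, adaLN-gated
// residual, and MLP.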
std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
                                                          Tensor encoder_hidden_states,
                                                          Tensor temb,
                                                          Tensor rotary_emb,
                                                          Tensor rotary_emb_context,
                                                          float sparsityRatio) {
    int batch_size = hidden_states.shape[0];
    assert(encoder_hidden_states.shape[0] == batch_size);

    nvtxRangePushA("JointTransformerBlock");

    nvtxRangePushA("AdaNorm");

    int num_tokens_img = hidden_states.shape[1];
    int num_tokens_txt = encoder_hidden_states.shape[1];

    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);

    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}",
                  hidden_states.shape.str(),
                  encoder_hidden_states.shape.str(),
                  temb.shape.str());
    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);

    auto norm1_output         = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

#if 0
    norm1_output.x = hidden_states;
    norm1_context_output.x = encoder_hidden_states;
#endif

    debug("norm_hidden_states", norm1_output.x);
    debug("norm_encoder_hidden_states", norm1_context_output.x);

    constexpr int POOL_SIZE = Attention::POOL_SIZE;

    nvtxRangePop();

    auto stream = getCurrentCUDAStream();

    int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
    Tensor raw_attn_output;

    if (attnImpl == AttentionImpl::FlashAttention2) {
        num_tokens_img_pad = num_tokens_img;
        num_tokens_txt_pad = num_tokens_txt;

        Tensor concat;
        Tensor pool;

        {
            nvtxRangePushA("qkv_proj");

            const bool blockSparse = sparsityRatio > 0;

            const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
            concat               = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3},
                                      norm1_output.x.scalar_type(),
                                      norm1_output.x.device());

            pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3},
                                                  norm1_output.x.scalar_type(),
                                                  norm1_output.x.device())
                               : Tensor{};

            for (int i = 0; i < batch_size; i++) {
                // img first
                Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
                Tensor qkv_context =
                    concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);

                Tensor pool_qkv =
                    pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
                Tensor pool_qkv_context = pool.valid()
                                              ? pool.slice(0, i, i + 1)
                                                    .slice(1,
                                                           num_tokens_img / POOL_SIZE,
                                                           num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
                                              : Tensor{};

                // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
                // debug("qkv_raw", qkv);

                debug("rotary_emb", rotary_emb);

                qkv_proj.forward(
                    norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
                debug("qkv", qkv);

                // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
                // debug("qkv_context_raw", qkv_context);

                debug("rotary_emb_context", rotary_emb_context);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         qkv_context,
                                         pool_qkv_context,
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context);
                debug("qkv_context", qkv_context);
            }

            nvtxRangePop();
        }

        spdlog::debug("concat={}", concat.shape.str());
        debug("concat", concat);

        assert(concat.shape[2] == num_heads * dim_head * 3);

        nvtxRangePushA("Attention");

        if (pool.valid()) {
            raw_attn_output = attn.forward(concat, pool, sparsityRatio);
        } else {
            raw_attn_output = attn.forward(concat);
        }

        nvtxRangePop();

        spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());

        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});

    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
        num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;

        Tensor concat_q, concat_k, concat_v;

        {
            nvtxRangePushA("qkv_proj");

            concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head},
                                        Tensor::FP16,
                                        norm1_output.x.device());
            concat_k = Tensor::empty_like(concat_q);
            concat_v = Tensor::empty_like(concat_q);

            for (int i = 0; i < batch_size; i++) {
                // img first
                auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
                auto sliceTxt = [&](Tensor x) {
                    return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
                };

                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1),
                                 {},
                                 {},
                                 norm_q.weight,
                                 norm_k.weight,
                                 rotary_emb,
                                 sliceImg(concat_q),
                                 sliceImg(concat_k),
                                 sliceImg(concat_v),
                                 num_tokens_img);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         {},
                                         {},
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context,
                                         sliceTxt(concat_q),
                                         sliceTxt(concat_k),
                                         sliceTxt(concat_v),
                                         num_tokens_txt);
            }

            debug("concat_q", concat_q);
            debug("concat_k", concat_k);
            debug("concat_v", concat_v);

            nvtxRangePop();
        }

        raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head},
                                           norm1_output.x.scalar_type(),
                                           norm1_output.x.device());

        nvtxRangePushA("Attention");

        kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));

        nvtxRangePop();

        raw_attn_output =
            raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
    } else {
        assert(false);
    }

    debug("raw_attn_output", raw_attn_output);

    {
        nvtxRangePushA("o_proj");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;

        // raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]
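        // With batch_size == 1 the image block is a simple slice + reshape; with
        // larger batches each batch row of the padded output is strided, so a 2D
        // async copy (source pitch = padded row length) gathers the image tokens
        // into a contiguous tensor.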

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split =
                raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("img.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
        debug("img.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
        hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", hidden_states.shape.str());

        Tensor norm_hidden_states = norm2.forward(hidden_states);
        debug("scale_mlp", scale_mlp);
        debug("shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        Tensor norm_hidden_states = hidden_states;
#endif

        // Tensor ff_output = mlp_fc2.forward(GELU::forward(mlp_fc1.forward(norm_hidden_states)));
        debug("img.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
        debug("img.ff_output", ff_output);

        debug("gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
        hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", hidden_states.shape.str());
    }

    if (context_pre_only) {
        return {hidden_states, encoder_hidden_states};
    }

    {
        nvtxRangePushA("o_proj_context");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_context_output;

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt)
                                        .reshape({batch_size, num_tokens_txt, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head *
                                                                               raw_attn_output_split.scalar_size(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("context.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj_context,
                       raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
        debug("context.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", encoder_hidden_states.shape.str());

        Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
        debug("c_scale_mlp", scale_mlp);
        debug("c_shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        auto norm_hidden_states = encoder_hidden_states;
#endif

        // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
        // Tensor ff_output =
        // mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
        debug("context.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
        debug("context.ff_output", ff_output);

        debug("c_gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", encoder_hidden_states.shape.str());
    }

    nvtxRangePop();

    return {hidden_states, encoder_hidden_states};
}
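
// Runs only the AdaNorm + QKV projection stage and returns the image stream's
// query heads as a contiguous [batch, num_tokens_img, num_heads * dim_head]
// tensor; this mirrors the q_heads extraction in forward_ip_adapter_branch below.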
Tensor JointTransformerBlock::get_q_heads(Tensor hidden_states,
                                          Tensor encoder_hidden_states,
                                          Tensor temb,
                                          Tensor rotary_emb,
                                          Tensor rotary_emb_context,
                                          float sparsityRatio) {
    int batch_size     = hidden_states.shape[0];
    int num_tokens_img = hidden_states.shape[1];
    int num_tokens_txt = encoder_hidden_states.shape[1];

    // Apply AdaNorm.
    auto norm1_output         = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

    Tensor concat = Tensor::allocate(
        {batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());

    const bool blockSparse  = sparsityRatio > 0;
    constexpr int POOL_SIZE = Attention::POOL_SIZE;
    const int poolTokens    = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
    Tensor pool =
        blockSparse
            ? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
            : Tensor{};

    // QKV Projection.
    for (int i = 0; i < batch_size; i++) {
        Tensor qkv         = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
        Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
        Tensor pool_qkv    = pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
        Tensor pool_qkv_context =
            pool.valid() ? pool.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, poolTokens) : Tensor{};

        qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
        qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                 qkv_context,
                                 pool_qkv_context,
                                 norm_added_q.weight,
                                 norm_added_k.weight,
                                 rotary_emb_context);
    }

    // Extract and return q_heads.
    Tensor q_all = concat.slice(2, 0, num_heads * dim_head);
    Tensor q_img = q_all.slice(1, 0, num_tokens_img);

    auto make_contiguous = [&](const Tensor &t) {
        int B            = t.shape.dataExtent[0];
        int R            = t.shape.dataExtent[1];
        int C            = t.shape.dataExtent[2];
        size_t E         = t.scalar_size();
        size_t src_pitch = t.stride(1) * E;
        size_t dst_pitch = C * E;
        size_t width     = C * E;
        size_t height    = R;
        Tensor out       = Tensor::allocate({B, R, C}, t.scalarType, t.device());
        auto stream      = getCurrentCUDAStream();
        for (int b = 0; b < B; ++b) {
            const void *src = (const char *)t.data_ptr<char>() + t.stride(0) * b * E;
            void *dst       = (char *)out.data_ptr<char>() + out.stride(0) * b * E;
            checkCUDA(
                cudaMemcpy2DAsync(dst, dst_pitch, src, src_pitch, width, height, cudaMemcpyDeviceToDevice, stream));
        }
        return out;
    };
    return make_contiguous(q_img);
}

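// Variant of the joint forward pass that, in addition to running attention,
// captures the image stream's query heads (q_heads) for the IP-adapter branch.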
std::tuple<Tensor, Tensor, Tensor> JointTransformerBlock::forward_ip_adapter_branch(Tensor hidden_states,
                                                                                    Tensor encoder_hidden_states,
                                                                                    Tensor temb,
                                                                                    Tensor rotary_emb,
                                                                                    Tensor rotary_emb_context,
                                                                                    float sparsityRatio) {
    int batch_size = hidden_states.shape[0];
    assert(encoder_hidden_states.shape[0] == batch_size);

    nvtxRangePushA("JointTransformerBlock");

    nvtxRangePushA("AdaNorm");

    int num_tokens_img = hidden_states.shape[1];
    int num_tokens_txt = encoder_hidden_states.shape[1];

    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);

    Tensor q_heads;

    auto make_contiguous = [&](const Tensor &t) {
        int B            = t.shape.dataExtent[0];
        int R            = t.shape.dataExtent[1];
        int C            = t.shape.dataExtent[2];
        size_t E         = t.scalar_size();
        size_t src_pitch = t.stride(1) * E;

        size_t dst_pitch = C * E;
        size_t width     = C * E;
        size_t height    = R;

        Tensor out = Tensor::allocate({B, R, C}, t.scalarType, t.device());

        auto stream = getCurrentCUDAStream();
        for (int b = 0; b < B; ++b) {
            const void *src = (const char *)t.data_ptr<char>() + t.stride(0) * b * E;
            void *dst       = (char *)out.data_ptr<char>() + out.stride(0) * b * E;
            checkCUDA(
                cudaMemcpy2DAsync(dst, dst_pitch, src, src_pitch, width, height, cudaMemcpyDeviceToDevice, stream));
        }
        return out;
    };

    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}",
                  hidden_states.shape.str(),
                  encoder_hidden_states.shape.str(),
                  temb.shape.str());
    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);

    auto norm1_output         = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

#if 0
    norm1_output.x = hidden_states;
    norm1_context_output.x = encoder_hidden_states;
#endif

    debug("norm_hidden_states", norm1_output.x);
    debug("norm_encoder_hidden_states", norm1_context_output.x);

    constexpr int POOL_SIZE = Attention::POOL_SIZE;

    nvtxRangePop();

    auto stream = getCurrentCUDAStream();

    int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
    Tensor raw_attn_output;

    if (attnImpl == AttentionImpl::FlashAttention2) {
        num_tokens_img_pad = num_tokens_img;
        num_tokens_txt_pad = num_tokens_txt;

        Tensor concat;
        Tensor pool;

        {
            nvtxRangePushA("qkv_proj");

            const bool blockSparse = sparsityRatio > 0;

            const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
            concat               = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3},
                                      norm1_output.x.scalar_type(),
                                      norm1_output.x.device());

            pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3},
                                                  norm1_output.x.scalar_type(),
                                                  norm1_output.x.device())
                               : Tensor{};

            for (int i = 0; i < batch_size; i++) {
                // img first
                Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
                Tensor qkv_context =
                    concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);

                Tensor pool_qkv =
                    pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
                Tensor pool_qkv_context = pool.valid()
                                              ? pool.slice(0, i, i + 1)
                                                    .slice(1,
                                                           num_tokens_img / POOL_SIZE,
                                                           num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
                                              : Tensor{};

                // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
                // debug("qkv_raw", qkv);

                debug("rotary_emb", rotary_emb);

                qkv_proj.forward(
                    norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
                debug("qkv", qkv);

                // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
                // debug("qkv_context_raw", qkv_context);

                debug("rotary_emb_context", rotary_emb_context);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         qkv_context,
                                         pool_qkv_context,
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context);
                debug("qkv_context", qkv_context);
            }

            nvtxRangePop();
        }

        spdlog::debug("concat={}", concat.shape.str());
        debug("concat", concat);

        assert(concat.shape[2] == num_heads * dim_head * 3);

        nvtxRangePushA("Attention");

        if (pool.valid()) {
            raw_attn_output = attn.forward(concat, pool, sparsityRatio);
        } else {
            raw_attn_output = attn.forward(concat);
        }

        nvtxRangePop();

        spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());

        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});

        // IP_adapter
        Tensor q_all = concat.slice(2, 0, num_heads * dim_head); // [B, N_total, dim]

        Tensor q_img = q_all.slice(1, 0, num_tokens_img); // [B, N_img, dim]

        q_heads = make_contiguous(q_img);

    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
        num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;

        Tensor concat_q, concat_k, concat_v;

        {
            nvtxRangePushA("qkv_proj");

            concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head},
                                        Tensor::FP16,
                                        norm1_output.x.device());
            concat_k = Tensor::empty_like(concat_q);
            concat_v = Tensor::empty_like(concat_q);

            for (int i = 0; i < batch_size; i++) {
                // img first
                auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
                auto sliceTxt = [&](Tensor x) {
                    return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
                };
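                // Per batch element the token axis is laid out as
                // [ image tokens (padded) | text tokens (padded) ]. qkv_proj writes Q, K, V
                // directly into these FP16 [B, H, N_pad, D] slices, applying the Q/K norms
                // and rotary embeddings in the same pass.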

                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1),
                                 {},
                                 {},
                                 norm_q.weight,
                                 norm_k.weight,
                                 rotary_emb,
                                 sliceImg(concat_q),
                                 sliceImg(concat_k),
                                 sliceImg(concat_v),
                                 num_tokens_img);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         {},
                                         {},
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context,
                                         sliceTxt(concat_q),
                                         sliceTxt(concat_k),
                                         sliceTxt(concat_v),
                                         num_tokens_txt);
            }

            debug("concat_q", concat_q);
            debug("concat_k", concat_k);
            debug("concat_v", concat_v);

            nvtxRangePop();
        }

        raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head},
                                           norm1_output.x.scalar_type(),
                                           norm1_output.x.device());

        nvtxRangePushA("Attention");

        kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));
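        // The last argument is the softmax scale, 1/sqrt(dim_head).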

        nvtxRangePop();

        raw_attn_output =
            raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});

        q_heads = concat_q;
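        // Note: in this branch q_heads keeps the kernel's [B, H, N_pad, D] layout, unlike
        // the flash-attention branch, which stores a flattened [B, N_img, H * D] copy.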
    } else {
        assert(false);
    }

    debug("raw_attn_output", raw_attn_output);

    {
        nvtxRangePushA("o_proj");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;

        // raw_attn_output: [batch_size, num_tokens_img(_pad) + num_tokens_txt(_pad), num_heads, dim_head]

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split =
                raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
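            // Strided device-to-device copy: for each batch element, take the first
            // num_tokens_img tokens of the attention output (the image block) out of a
            // source pitch spanning the padded image + text tokens, and pack them
            // contiguously into raw_attn_output_split.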
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("img.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
        debug("img.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, hidden_states);
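        // Fused elementwise update: scale the attention output by the MSA gate and add the
        // residual hidden_states in one batched kernel, writing into attn_output.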
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
        hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", hidden_states.shape.str());

        Tensor norm_hidden_states = norm2.forward(hidden_states);
        debug("scale_mlp", scale_mlp);
        debug("shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        Tensor norm_hidden_states = hidden_states;
#endif

        // Tensor ff_output = mlp_fc2.forward(GELU::forward(mlp_fc1.forward(norm_hidden_states)));
        debug("img.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
        debug("img.ff_output", ff_output);

        debug("gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
        hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", hidden_states.shape.str());
    }

    if (context_pre_only) {
        return {hidden_states, encoder_hidden_states, q_heads};
    }

    {
        nvtxRangePushA("o_proj_context");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_context_output;

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt)
                                        .reshape({batch_size, num_tokens_txt, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
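            // Same strided copy as in the image branch, but the source is offset by the
            // padded image tokens so only each batch element's text block is extracted.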
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head *
                                                                               raw_attn_output_split.scalar_size(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("context.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj_context,
                       raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
        debug("context.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", encoder_hidden_states.shape.str());

        Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
        debug("c_scale_mlp", scale_mlp);
        debug("c_shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        auto norm_hidden_states = encoder_hidden_states;
#endif

        // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
        // Tensor ff_output =
        // mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
        debug("context.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
        debug("context.ff_output", ff_output);

        debug("c_gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", encoder_hidden_states.shape.str());
    }

    nvtxRangePop();

    return {hidden_states, encoder_hidden_states, q_heads};
}

FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
    : dtype(dtype), offload(offload) {
    CUDADeviceContext model_construction_ctx(device.idx);
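    // FLUX.1 architecture: 19 joint (dual-stream) transformer blocks followed by
    // 38 single-stream blocks.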

    for (int i = 0; i < 19; i++) {
        transformer_blocks.push_back(
            std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
        registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
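        // When offloading, every block except the first is marked lazy and its parameters
        // are released right away; LayerOffloadHelper reloads them on demand during forward.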
        if (offload && i > 0) { // don't offload first block
            transformer_blocks.back()->setLazyLoad(true);
            transformer_blocks.back()->releaseLazyParams();
        }
    }
    for (int i = 0; i < 38; i++) {
        single_transformer_blocks.push_back(
            std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, device));
        registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
        if (offload) {
            single_transformer_blocks.back()->setLazyLoad(true);
            single_transformer_blocks.back()->releaseLazyParams();
        }
    }
}

Tensor FluxModel::forward(Tensor hidden_states,
                          Tensor encoder_hidden_states,
                          Tensor temb,
                          Tensor rotary_emb_img,
                          Tensor rotary_emb_context,
                          Tensor rotary_emb_single,
                          Tensor controlnet_block_samples,
                          Tensor controlnet_single_block_samples,
                          bool skip_first_layer) {
    const int batch_size           = hidden_states.shape[0];
    const Tensor::ScalarType dtype = hidden_states.dtype();
    const Device device            = hidden_states.device();

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();

    Tensor concat;

    auto compute = [&](int layer) {
        if (skip_first_layer && size_t(layer) == 0)
            return;
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            std::tie(hidden_states, encoder_hidden_states) =
                block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
            if (controlnet_block_samples.valid()) {
                const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

                int interval_control =
                    ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
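                // Map layers onto ControlNet samples by ceiling division: e.g. 19 layers
                // with 2 samples gives interval_control = 10, so layers 0-9 use sample 0
                // and layers 10-18 use sample 1.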
                int block_index = layer / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_block_samples;

                hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
            }
            if (residual_callback && layer % 2 == 0) {
                Tensor residual = residual_callback(hidden_states);
                hidden_states   = kernels::add(hidden_states, residual);
            }
        } else {
            if (size_t(layer) == transformer_blocks.size()) {
                // txt first, same as diffusers
                concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
                for (int i = 0; i < batch_size; i++) {
                    concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states.slice(0, i, i + 1));
                    concat.slice(0, i, i + 1)
                        .slice(1, txt_tokens, txt_tokens + img_tokens)
                        .copy_(hidden_states.slice(0, i, i + 1));
                }
                hidden_states         = concat;
                encoder_hidden_states = {};
            }
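            // From here on hidden_states carries the packed [txt | img] sequence and the
            // single-stream blocks run without a separate encoder stream.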

            auto &block   = single_transformer_blocks.at(layer - transformer_blocks.size());
            hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
            if (controlnet_single_block_samples.valid()) {
                const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

                int interval_control =
                    ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
                int block_index = (layer - transformer_blocks.size()) / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_single_block_samples

                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice      = kernels::add(slice, controlnet_single_block_samples[block_index]);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
            size_t local_layer_idx = layer - transformer_blocks.size();
            if (residual_callback && local_layer_idx % 4 == 0) {
                Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                Tensor residual       = residual_callback(callback_input);
                auto slice            = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice                 = kernels::add(slice, residual);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
        }
    };
    auto load = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->loadLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->loadLazyParams();
        }
    };
    auto unload = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->releaseLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->releaseLazyParams();
        }
    };
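    // LayerOffloadHelper drives the load/compute/unload lambdas per layer; with offloading
    // enabled it presumably prefetches the next block's weights so host-to-device copies
    // overlap with compute.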

    LayerOffloadHelper helper(this->offload, numLayers, compute, load, unload);
    helper.run();

    return hidden_states;
}

std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
                                                    Tensor hidden_states,
                                                    Tensor encoder_hidden_states,
                                                    Tensor temb,
                                                    Tensor rotary_emb_img,
                                                    Tensor rotary_emb_context,
                                                    Tensor controlnet_block_samples,
                                                    Tensor controlnet_single_block_samples) {

    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->loadLazyParams();
        } else {
            transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
        }
    }

    if (layer < transformer_blocks.size()) {
        std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
            hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    } else {
        std::tie(hidden_states, encoder_hidden_states) =
            transformer_blocks.at(layer - transformer_blocks.size())
                ->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    }

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
        const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
        int block_index      = layer / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_block_samples;

        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
    } else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
        const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

        int interval_control =
            ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
        int block_index = (layer - transformer_blocks.size()) / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_single_block_samples

        auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
        slice      = kernels::add(slice, controlnet_single_block_samples[block_index]);
        hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
    }

    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->releaseLazyParams();
        } else {
            transformer_blocks.at(layer - transformer_blocks.size())->releaseLazyParams();
        }
    }

    return {hidden_states, encoder_hidden_states};
}
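// Variant of forward_layer used for IP-Adapter: besides the updated hidden states it also
// returns the per-head image queries (ip_query) expected to feed the adapter's extra
// cross-attention.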

std::tuple<Tensor, Tensor, Tensor> FluxModel::forward_ip_adapter(size_t layer,
                                                                 Tensor hidden_states,         // [B, Nq, dim]
                                                                 Tensor encoder_hidden_states, // [B, Nt, dim]
                                                                 Tensor temb,
                                                                 Tensor rotary_emb_img, // [B, Nq, dim_head]
                                                                 Tensor rotary_emb_context,
                                                                 Tensor controlnet_block_samples,
                                                                 Tensor controlnet_single_block_samples) {
    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->loadLazyParams();
        } else {
            transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
        }
    }

    std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
        hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    Tensor ip_query = transformer_blocks.at(layer)->get_q_heads(
        hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);

    if (controlnet_block_samples.valid()) {
        const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
        int block_index      = layer / interval_control;

        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
    }

    if (offload && layer > 0) {
        transformer_blocks.at(layer)->releaseLazyParams();
    }

    return {hidden_states, encoder_hidden_states, ip_query};
}

void FluxModel::setAttentionImpl(AttentionImpl impl) {
    for (auto &&block : this->transformer_blocks) {
        block->attnImpl = impl;
    }
    for (auto &&block : this->single_transformer_blocks) {
        block->attnImpl = impl;
    }
}
void FluxModel::set_residual_callback(std::function<Tensor(const Tensor &)> cb) {
    residual_callback = std::move(cb);
}