FluxModel.cpp 66.1 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
#include "FluxModel.h"
#include "kernels/misc_kernels.h"
#include "kernels/gemm_batched.h"
4
#include "kernels/zgemm/zgemm.h"
Zhekai Zhang's avatar
Zhekai Zhang committed
5
#include "activation.h"
fengzch's avatar
fengzch committed
6
#include "Tensor.h"
limm's avatar
limm committed
7
8
// #include <nvtx3/nvToolsExt.h>
#include <roctx.h>
Zhekai Zhang's avatar
Zhekai Zhang committed
9

Muyang Li's avatar
Muyang Li committed
10
11
#include <pybind11/functional.h>

fengzch's avatar
fengzch committed
12
#include <flash_c_api.h>
Zhekai Zhang's avatar
Zhekai Zhang committed
13
14
15
#include <iostream>

using spdlog::fmt_lib::format;
muyangli's avatar
muyangli committed
16
using namespace nunchaku;
Zhekai Zhang's avatar
Zhekai Zhang committed
17

fengzch's avatar
fengzch committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// Dense (non-varlen) FlashAttention forward pass through the flash_c_api C
// entry points: sizes a scratch workspace, allocates the output, and
// dispatches mha_fwd with explicit shapes and strides.
//
// Returns the attention output `o` with the same shape/dtype as `q`.
// NOTE(review): argument order of mha_fwd below is positional and
// order-critical — keep it in sync with flash_c_api.h.
Tensor call_fa_mha_fwd(Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                       Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
                       Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
                       // c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
                       // c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
                       const float p_dropout,
                       const float softmax_scale,
                       bool is_causal,
                       int window_size_left,
                       int window_size_right,
                       const bool return_softmax
                       // c10::optional<at::Generator> gen_
) {
    // printf("LOG(INFO) %s: %d %s\n", __FILE__, __LINE__, __func__);
    // Output has the same [B, Sq, H, D] layout as q.
    Tensor o = Tensor::empty_like(q);
    // Query the kernel for its scratch requirement for these problem sizes.
    size_t workspace_size = mha_fwd_workspace(
        q.shape[0], q.shape[1], k.shape[1],
        q.shape[2], k.shape[2],
        q.shape[3], k.shape[3],
        false
    );

    const Device device  = q.device();
    // Raw byte buffer for the kernel scratch space, on q's device.
    Tensor workspace = Tensor::allocate({1, 1, 1, (int)workspace_size}, Tensor::INT8, device);

    mha_fwd(
        q.data_ptr(), k.data_ptr(), v.data_ptr(), o.data_ptr(),
        nullptr,                                    //* alibi
        nullptr,                                    //* rng_state
        workspace.data_ptr(),                       //* workspace
        q.shape[0], q.shape[1], k.shape[1],         //* sizes
        q.shape[2], k.shape[2],
        q.shape[3], k.shape[3],

        q.stride(0), q.stride(1), q.stride(2), q.stride(3),                         //* q strides
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),                         //* k strides
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),                         //* v strides
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),                         //* o strides

        1, 1,                                           //* alibi strides
        p_dropout,                                      //* p_dropout
        softmax_scale,                                  //* softmax_scale
        is_causal,                                      //* is_causal
        window_size_left,
        window_size_right,                              //* window sizes
        0.0f,                                           //* softcap
        return_softmax,                                 //* return_softmax
        0,                                              //* seed
        q.scalar_type() == Tensor::ScalarType::BF16,    //* is_bf16
        false                                           //* is_bhsd
    );

    return o;
}

Zhekai Zhang's avatar
Zhekai Zhang committed
73
// Fused W4A4 MLP: fc1 runs with fused GELU + quantization targeting fc2, and
// fc2 consumes the quantized activation directly via forward_quant — avoiding
// a dequantize/requantize round trip between the two GEMMs.
Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
    // NOTE(review): was a bare `std::cout` trace left in from debugging;
    // routed through spdlog to match the logging style used in this file.
    spdlog::debug("Called forward_mlp");

    Tensor ff_output = fc2.forward_quant(std::get<GEMM_W4A4::QuantizedActivation>(
        fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2)));
    return ff_output;
}

// Tensor forward_mlp(GEMM_W8A8 &fc1, GEMM_W8A8 &fc2, Tensor norm_hidden_states) {
//     Tensor ff_output = fc2.forward(fc1.forward(norm_hidden_states), GEMM_W8A8::FuseOptions::GELU);
//     return ff_output;
// }

// Thin wrapper: run a single W4A4 quantized linear layer on `x`.
// The no-fuse overload of fc.forward() is used here, which yields a plain
// Tensor (the variant-returning form is kept below for reference).
Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
    return fc.forward(x);
    // return std::get<Tensor>(fc.forward(x));
}

// Tensor forward_fc(GEMM_W8A8 &fc, Tensor x) {
//     return fc.forward(x);
// }

Muyang Li's avatar
Muyang Li committed
94
95
96
// adaLN-Zero for the single-stream transformer block: a dim -> 3*dim linear
// (producing shift/scale/gate modulation) over a non-affine LayerNorm
// (eps = 1e-6, no learned weight/bias).
AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device)
    : dim(dim), linear(dim, 3 * dim, true, dtype, device), norm(dim, 1e-6, false, dtype, device) {
    registerChildren(linear, "linear")(norm, "norm");
}

// adaLN-Zero modulation (single-stream): project SiLU(emb) into three
// modulation chunks, apply the fused scale/shift to LayerNorm(x), and hand
// the gate back to the caller for the residual connection.
AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor emb) {
    debug("emb_input", emb);
    Tensor activated = Silu::forward(emb);
    emb = linear.forward(activated);
    debug("emb_linear", emb);

    // Chunk the projection into shift / scale / gate.
    auto &&[shift_msa, scale_msa, gate_msa] = kernels::split_mod<3>(emb);
    debug("scale_msa", scale_msa);
    debug("shift_msa", shift_msa);

    debug("x", x);
    Tensor normed = norm.forward(x);
    debug("norm_x", normed);

    // Fused in-place modulation of the normalized activations.
    // kernels::mul_add(norm_x, scale_msa, shift_msa);
    kernels::mul_add_batch(normed, scale_msa, true, 0.0, shift_msa, true);

    return Output{normed, gate_msa};
}

Muyang Li's avatar
Muyang Li committed
116
117
118
119
// adaLN-Zero for the dual-stream block. `pre_only` selects the reduced
// variant: the linear projects dim -> 2*dim (shift/scale only) instead of
// dim -> 6*dim (shift/scale/gate for both attention and MLP).
// LayerNorm is non-affine with eps = 1e-6.
AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device)
    : dim(dim), pre_only(pre_only), linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
      norm(dim, 1e-6, false, dtype, device) {
    registerChildren(linear, "linear")(norm, "norm");
}

// adaLN-Zero modulation (dual-stream): project SiLU(emb), split into either
// 2 (pre_only) or 6 modulation chunks, and scale/shift LayerNorm(x).
// In the full variant the attention gate plus the MLP shift/scale/gate are
// returned untouched for the caller to apply later.
AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
    debug("x", x);

    debug("emb_input", emb);
    Tensor activated = Silu::forward(emb);
    emb = linear.forward(activated);
    debug("emb_linear", emb);

    if (pre_only) {
        // Reduced variant: shift/scale only, no gates.
        auto &&[shift_msa, scale_msa] = kernels::split_mod<2>(emb);
        debug("shift_msa", shift_msa);

        Tensor normed = norm.forward(x);
        debug("norm_x", normed);

        // Fused in-place modulation of the normalized activations.
        // kernels::mul_add(norm_x, scale_msa, shift_msa);
        kernels::mul_add_batch(normed, scale_msa, true, 0.0, shift_msa, true);
        debug("norm_x_scaled", normed);

        return Output{normed};
    } else {
        // Full variant: attention shift/scale/gate plus MLP shift/scale/gate.
        auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(emb);
        debug("shift_msa", shift_msa);

        Tensor normed = norm.forward(x);
        debug("norm_x", normed);

        // Fused in-place modulation of the normalized activations.
        // kernels::mul_add(norm_x, scale_msa, shift_msa);
        kernels::mul_add_batch(normed, scale_msa, true, 0.0, shift_msa, true);
        debug("norm_x_scaled", normed);

        return Output{normed, gate_msa, shift_mlp, scale_mlp, gate_mlp};
    }
}

Muyang Li's avatar
Muyang Li committed
156
157
// Build the per-head mask-type id table (1..num_heads) on the host, then
// move it to the target device. force_fp16 starts disabled; see
// Attention::setForceFP16.
Attention::Attention(int num_heads, int dim_head, Device device)
    : num_heads(num_heads), dim_head(dim_head), force_fp16(false) {
    headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
    int32_t *mask = headmask_type.data_ptr<int32_t>();
    for (int h = 0; h < num_heads; ++h) {
        mask[h] = h + 1; // mask ids are 1-based
    }
    headmask_type = headmask_type.copy(device);
}

165
166
167
// Dense attention over a fused QKV tensor via the FlashAttention forward
// entry point (call_fa_mha_fwd).
//
// qkv:     [batch_size, num_tokens, num_heads * dim_head * 3]
// returns: [batch_size * num_tokens, num_heads, dim_head]
Tensor Attention::forward(Tensor qkv) {
    assert(qkv.ndims() == 3);

    // NOTE(review): removed an unused local `const Device device = qkv.device();`.
    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    // View as [B, T, 3*H, D] and split along the fused head axis into q/k/v.
    Tensor reshaped = qkv.view({batch_size, num_tokens, num_heads * 3, dim_head});
    Tensor q        = reshaped.slice(2, 0, num_heads);
    Tensor k        = reshaped.slice(2, num_heads, num_heads * 2);
    Tensor v        = reshaped.slice(2, num_heads * 2, num_heads * 3);

    // 1/sqrt(head_dim) softmax scaling; no dropout, non-causal, no windowing.
    Tensor raw_attn_output = call_fa_mha_fwd(q, k, v, 0.0f, pow(q.shape[-1], (-0.5)), false, -1, -1, false);

    assert(raw_attn_output.shape[0] == batch_size);
    assert(raw_attn_output.shape[1] == num_tokens);
    assert(raw_attn_output.shape[2] == num_heads);
    assert(raw_attn_output.shape[3] == dim_head);

    // Flatten batch and token dims for the downstream output projection.
    return raw_attn_output.view({batch_size * num_tokens, num_heads, dim_head});
}

Zhekai Zhang's avatar
Zhekai Zhang committed
188
// Block-sparse attention path. Computes pooled q·k scores to derive a
// block mask (top-k by sparsity ratio) and maintains a cached cu_seqlens
// table for the varlen kernel.
//
// WARNING: the actual block-sparse kernel (mha_fwd_block) is not available
// in this build — the function currently returns a PLACEHOLDER tensor of
// ones and only logs a warning. Do not rely on its numerical output.
//
// qkv:      [batch_size, num_tokens, num_heads * dim_head * 3]
// pool_qkv: optional pooled QKV, [batch_size, ceil(num_tokens/128), dim*3]
// returns:  [batch_size * num_tokens, num_heads, dim_head]
Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
    // force_fp16 requires casting inputs (and casting the result back).
    const bool cast_fp16 = this->force_fp16 && qkv.scalar_type() != Tensor::FP16;

    assert(qkv.ndims() == 3);

    const Device device  = qkv.device();
    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    constexpr int POOL_SIZE = 128;
    const int pool_tokens   = ceilDiv(num_tokens, POOL_SIZE);

    Tensor blockmask;

    if (pool_qkv.valid()) {
        assert(pool_qkv.shape[0] == batch_size);
        assert(pool_qkv.shape[1] == pool_tokens);
        assert(pool_qkv.shape[2] == num_heads * dim_head * 3);
    }

    // NOTE(review): pool_score is allocated (and topk'd below) even when
    // sparsity is disabled and the scores are never written — confirm the
    // kernels tolerate this before tightening.
    Tensor pool_score = Tensor::allocate({batch_size, num_heads, pool_tokens, pool_tokens}, Tensor::FP32, device);

    if (pool_qkv.valid() && sparsityRatio > 0) {
        pool_qkv = pool_qkv.view({batch_size, pool_tokens, 3, num_heads, dim_head});
        pool_qkv = pool_qkv.transpose(1, 2).transpose(2, 3); // [batch_size, 3, num_heads, poolTokens, dim_head]
        for (int i = 0; i < batch_size; i++) {
            // Pooled q·k scores per batch element via batched fp16 GEMM.
            Tensor pool_q = pool_qkv.slice(0, i, i + 1).slice(1, 0, 1);
            Tensor pool_k = pool_qkv.slice(0, i, i + 1).slice(1, 1, 2);
            Tensor pool_s = pool_score.slice(0, i, i + 1);
            gemm_batched_fp16(pool_q, pool_k, pool_s);
        }
    }

    // Keep the top (1 - sparsityRatio) fraction of blocks per row.
    blockmask = kernels::topk(pool_score, pool_tokens * (1 - sparsityRatio));

    // Reuse the cached cu_seqlens if it still matches [0, T, 2T, ...];
    // otherwise invalidate and rebuild it below.
    if (cu_seqlens_cpu.valid()) {
        if (cu_seqlens_cpu.shape[0] != batch_size + 1) {
            cu_seqlens_cpu = Tensor{};
        } else {
            for (int i = 0; i <= batch_size; i++) {
                if (cu_seqlens_cpu.data_ptr<int32_t>()[i] != num_tokens * i) {
                    cu_seqlens_cpu = Tensor{};
                    break;
                }
            }
        }
    }
    if (!cu_seqlens_cpu.valid()) {
        cu_seqlens_cpu                        = Tensor::allocate({batch_size + 1}, Tensor::INT32, Device::cpu());
        cu_seqlens_cpu.data_ptr<int32_t>()[0] = 0;
        for (int i = 1; i <= batch_size; i++) {
            cu_seqlens_cpu.data_ptr<int32_t>()[i] = cu_seqlens_cpu.data_ptr<int32_t>()[i - 1] + num_tokens;
        }
    }

    if (cast_fp16) {
        Tensor tmp = Tensor::empty(qkv.shape.dataExtent, Tensor::FP16, qkv.device());
        kernels::cast(qkv, tmp);
        qkv = tmp;
    }

    debug("qkv", qkv);

    Tensor cu_seqlens = cu_seqlens_cpu.copy(device);

    // Flatten to varlen layout [B*T, 3*H, D] and split into q/k/v views.
    Tensor reshaped = qkv.view({batch_size * num_tokens, num_heads * 3, dim_head});
    Tensor q        = reshaped.slice(1, 0, num_heads);
    Tensor k        = reshaped.slice(1, num_heads, num_heads * 2);
    Tensor v        = reshaped.slice(1, num_heads * 2, num_heads * 3);

    spdlog::debug("q,k,v={}", q.shape.str());

    // Intended implementation (block-sparse flash attention), kept for when
    // mha_fwd_block becomes available on this platform:
    // Tensor raw_attn_output = mha_fwd_block(q,
    //                                        k,
    //                                        v,
    //                                        cu_seqlens,
    //                                        cu_seqlens,
    //                                        POOL_SIZE,
    //                                        POOL_SIZE,
    //                                        headmask_type,
    //                                        {},
    //                                        blockmask,
    //                                        num_tokens,
    //                                        num_tokens,
    //                                        0.0f,
    //                                        pow(q.shape[-1], (-0.5)),
    //                                        false,
    //                                        false,
    //                                        false,
    //                                        -1,
    //                                        -1)
    //                              .front();
    // NOTE(review): placeholder output — was allocated on the hardcoded
    // Device::cuda(); now allocated on qkv's device for consistency.
    Tensor raw_attn_output = Tensor::ones({batch_size * num_tokens, num_heads, dim_head}, Tensor::FP16, device);
    // NOTE(review): was a bare std::cout message; routed through spdlog.
    spdlog::warn("mha_fwd_block is not supported; returning placeholder attention output");

    debug("raw_attn_output", raw_attn_output);

    if (cast_fp16) {
        // Cast the result back to BF16 when the inputs were force-cast to FP16.
        Tensor tmp = Tensor::empty(raw_attn_output.shape.dataExtent, Tensor::BF16, raw_attn_output.device());
        kernels::cast(raw_attn_output, tmp);
        raw_attn_output = tmp;
    }

    assert(raw_attn_output.shape[0] == batch_size * num_tokens);
    assert(raw_attn_output.shape[1] == num_heads);
    assert(raw_attn_output.shape[2] == dim_head);

    return raw_attn_output;
}

328
329
330
331
332
333
334
335
336
337
// Walk the module tree rooted at `module` and set the force-fp16 flag on
// every Attention instance found.
void Attention::setForceFP16(Module *module, bool value) {
    spdlog::info("{} force fp16 attention", value ? "Enable" : "Disable");

    module->traverse([&](Module *node) {
        Attention *attn = dynamic_cast<Attention *>(node);
        if (attn != nullptr) {
            attn->force_fp16 = value;
        }
    });
}

Muyang Li's avatar
Muyang Li committed
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
// Single-stream Flux transformer block: adaLN-Zero norm, fused QKV projection
// with per-head-dim Q/K RMS-style norms, attention, an MLP whose hidden width
// is dim * mlp_ratio, and an output projection. All linears are quantized
// (W4A4, optionally FP4).
FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim,
                                                       int num_attention_heads,
                                                       int attention_head_dim,
                                                       int mlp_ratio,
                                                       bool use_fp4,
                                                       Tensor::ScalarType dtype,
                                                       Device device)
    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
      mlp_hidden_dim(dim * mlp_ratio), norm(dim, dtype, device),
      mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
      mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
      norm_q(dim_head, 1e-6, false, dtype, device), norm_k(dim_head, 1e-6, false, dtype, device),
      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
      out_proj(dim, dim, true, use_fp4, dtype, device) {
    registerChildren(norm, "norm")(mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(qkv_proj, "qkv_proj")(norm_q, "norm_q")(
        norm_k, "norm_k")(attn, "attn")(out_proj, "out_proj");
}

// Single-stream block forward pass:
//   1. adaLN-Zero norm produces modulated activations and a gate.
//   2. QKV projection (with fused Q/K norm + rotary embedding) feeds either
//      the FlashAttention2 path or the Nunchaku FP16 attention kernel.
//   3. Attention output projection and the GELU MLP run on the same
//      normalized input; their sum is gated and added to the residual.
//
// hidden_states: [batch_size, num_tokens, dim]; returned with same shape.
Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {

    nvtxRangePushA("FluxSingleTransformerBlock");

    const int batch_size = hidden_states.shape[0];
    const int num_tokens = hidden_states.shape[1];

    auto &&[norm_hidden_states, gate] = this->norm.forward(hidden_states, temb);
    debug("norm_hidden_states", norm_hidden_states);
    debug("gate", gate);

    // Keep the pre-norm input for the gated residual connection at the end.
    Tensor residual = hidden_states;

    Tensor attn_output;

    debug("rotary_emb", rotary_emb);

    if (attnImpl == AttentionImpl::FlashAttention2) {
        Tensor qkv = Tensor::allocate(
            {batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
        // qkv_proj.forward(norm_hidden_states, qkv, {});
        // debug("qkv_raw", qkv);

        // Per-batch-element projection with fused Q/K norm + rotary embedding.
        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
                             qkv.slice(0, i, i + 1),
                             {},
                             norm_q.weight,
                             norm_k.weight,
                             rotary_emb);
        }
        debug("qkv", qkv);
        // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);

        // attn_output = attn.forward(qkv, {}, 0);
        attn_output = attn.forward(qkv);
        // attn.forward returns [B*T, H, D]; restore [B, T, H*D].
        attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        // assert(batch_size == 1);

        // The FP16 kernel wants the token dim padded to a multiple of 256.
        const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;

        // Packed [B, H, T_pad, D] FP16 buffers filled by the fused projection.
        Tensor q = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor k = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor v = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());

        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
                             {},
                             {},
                             norm_q.weight,
                             norm_k.weight,
                             rotary_emb,
                             q.slice(0, i, i + 1),
                             k.slice(0, i, i + 1),
                             v.slice(0, i, i + 1),
                             num_tokens);
        }

        debug("packed_q", q);
        debug("packed_k", k);
        debug("packed_v", v);

        Tensor o = Tensor::allocate({batch_size, num_tokens_pad, num_heads * dim_head},
                                    norm_hidden_states.scalar_type(),
                                    norm_hidden_states.device());

        // 1/sqrt(head_dim) softmax scaling.
        kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));

        if (batch_size == 1 || num_tokens_pad == num_tokens) {
            // Padding rows are contiguous at the end; a plain slice suffices.
            attn_output = o.slice(1, 0, num_tokens);
        } else {
            // Padded rows are interleaved per batch element: strided 2D copy
            // extracts the first num_tokens rows of each batch element.
            attn_output = Tensor::allocate({batch_size, num_tokens, num_heads * dim_head}, o.scalar_type(), o.device());
            checkCUDA(cudaMemcpy2DAsync(attn_output.data_ptr(),
                                        attn_output.stride(0) * attn_output.scalar_size(),
                                        o.data_ptr(),
                                        o.stride(0) * o.scalar_size(),
                                        attn_output.stride(0) * attn_output.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        getCurrentCUDAStream()));
        }
    } else {
        assert(false);
    }

    debug("raw_attn_output", attn_output);

    attn_output = forward_fc(out_proj, attn_output);
    debug("attn_output", attn_output);

    // MLP branch runs on the same normalized input as attention.
    Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
    debug("ff_output", ff_output);

    hidden_states = kernels::add(attn_output, ff_output);
    debug("attn_ff_output", hidden_states);

    // Gated residual: hidden_states = gate * hidden_states + residual (fused).
    // kernels::mul_add(hidden_states, gate, residual);
    kernels::mul_add_batch(hidden_states, gate, true, 0.0, residual, true);

    nvtxRangePop();

    return hidden_states;
}

Muyang Li's avatar
Muyang Li committed
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
// Dual-stream (joint) Flux transformer block: separate adaLN-Zero norms, QKV
// projections, output projections, and 4x MLPs for the image stream and the
// text/context stream, sharing one attention module. When context_pre_only
// is set, the context stream uses the reduced norm and (per the ctor wiring)
// skips its post-attention layers.
JointTransformerBlock::JointTransformerBlock(int dim,
                                             int num_attention_heads,
                                             int attention_head_dim,
                                             bool context_pre_only,
                                             bool use_fp4,
                                             Tensor::ScalarType dtype,
                                             Device device)
    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
      context_pre_only(context_pre_only), norm1(dim, false, dtype, device),
      norm1_context(dim, context_pre_only, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
      qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device), norm_q(dim_head, 1e-6, false, dtype, device),
      norm_k(dim_head, 1e-6, false, dtype, device), norm_added_q(dim_head, 1e-6, false, dtype, device),
      norm_added_k(dim_head, 1e-6, false, dtype, device),
      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
      out_proj(dim, dim, true, use_fp4, dtype, device), out_proj_context(dim, dim, true, use_fp4, dtype, device),
      norm2(dim, 1e-6, false, dtype, device), norm2_context(dim, 1e-6, false, dtype, device),
      mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device), mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
      mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
      mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device) {
    registerChildren(norm1, "norm1")(norm1_context, "norm1_context")(qkv_proj, "qkv_proj")(qkv_proj_context,
                                                                                           "qkv_proj_context")(
        norm_q, "norm_q")(norm_k, "norm_k")(norm_added_q, "norm_added_q")(norm_added_k, "norm_added_k")(attn, "attn")(
        out_proj, "out_proj")(out_proj_context, "out_proj_context")(norm2, "norm2")(norm2_context, "norm2_context")(
        mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(mlp_context_fc1, "mlp_context_fc1")(mlp_context_fc2, "mlp_context_fc2");
}

// hidden_states: [Batch, Width * Height, dim]
// encoder_hidden_states: [Batch, Token, dim]
Muyang Li's avatar
Muyang Li committed
492
493
494
495
496
497
std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
                                                          Tensor encoder_hidden_states,
                                                          Tensor temb,
                                                          Tensor rotary_emb,
                                                          Tensor rotary_emb_context,
                                                          float sparsityRatio) {
Zhekai Zhang's avatar
Zhekai Zhang committed
498
499
500
    int batch_size = hidden_states.shape[0];
    assert(encoder_hidden_states.shape[0] == batch_size);

fengzch-das's avatar
fengzch-das committed
501
    nvtxRangePushA("JointTransformerBlock");
Zhekai Zhang's avatar
Zhekai Zhang committed
502

fengzch-das's avatar
fengzch-das committed
503
    nvtxRangePushA("AdaNorm");
Zhekai Zhang's avatar
Zhekai Zhang committed
504
505

    int num_tokens_img = hidden_states.shape[1];
506
    int num_tokens_txt = encoder_hidden_states.shape[1];
Hyunsung Lee's avatar
Hyunsung Lee committed
507

Zhekai Zhang's avatar
Zhekai Zhang committed
508
509
510
    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);

Muyang Li's avatar
Muyang Li committed
511
512
513
514
    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}",
                  hidden_states.shape.str(),
                  encoder_hidden_states.shape.str(),
                  temb.shape.str());
515
    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);
Zhekai Zhang's avatar
Zhekai Zhang committed
516

Muyang Li's avatar
Muyang Li committed
517
    auto norm1_output         = norm1.forward(hidden_states, temb);
Zhekai Zhang's avatar
Zhekai Zhang committed
518
519
520
521
522
523
524
525
526
527
528
529
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

#if 0
    norm1_output.x = hidden_states;
    norm1_context_output.x = encoder_hidden_states;
#endif

    debug("norm_hidden_states", norm1_output.x);
    debug("norm_encoder_hidden_states", norm1_context_output.x);

    constexpr int POOL_SIZE = Attention::POOL_SIZE;

fengzch-das's avatar
fengzch-das committed
530
    nvtxRangePop();
Zhekai Zhang's avatar
Zhekai Zhang committed
531

fengzch-das's avatar
fengzch-das committed
532
    auto stream = getCurrentCUDAStream();
Hyunsung Lee's avatar
Hyunsung Lee committed
533

534
535
    int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
    Tensor raw_attn_output;
Zhekai Zhang's avatar
Zhekai Zhang committed
536

537
538
539
    if (attnImpl == AttentionImpl::FlashAttention2) {
        num_tokens_img_pad = num_tokens_img;
        num_tokens_txt_pad = num_tokens_txt;
540

541
542
        Tensor concat;
        Tensor pool;
Hyunsung Lee's avatar
Hyunsung Lee committed
543

544
        {
fengzch-das's avatar
fengzch-das committed
545
            nvtxRangePushA("qkv_proj");
Hyunsung Lee's avatar
Hyunsung Lee committed
546

547
            const bool blockSparse = sparsityRatio > 0;
Hyunsung Lee's avatar
Hyunsung Lee committed
548

549
            const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
Muyang Li's avatar
Muyang Li committed
550
551
552
            concat               = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3},
                                      norm1_output.x.scalar_type(),
                                      norm1_output.x.device());
Hyunsung Lee's avatar
Hyunsung Lee committed
553

Muyang Li's avatar
Muyang Li committed
554
555
556
557
            pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3},
                                                  norm1_output.x.scalar_type(),
                                                  norm1_output.x.device())
                               : Tensor{};
Hyunsung Lee's avatar
Hyunsung Lee committed
558

559
560
561
            for (int i = 0; i < batch_size; i++) {
                // img first
                Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
Muyang Li's avatar
Muyang Li committed
562
563
                Tensor qkv_context =
                    concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
Hyunsung Lee's avatar
Hyunsung Lee committed
564

Muyang Li's avatar
Muyang Li committed
565
566
                Tensor pool_qkv =
                    pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
567
                Tensor pool_qkv_context = pool.valid()
Muyang Li's avatar
Muyang Li committed
568
569
570
571
572
                                              ? pool.slice(0, i, i + 1)
                                                    .slice(1,
                                                           num_tokens_img / POOL_SIZE,
                                                           num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
                                              : Tensor{};
Hyunsung Lee's avatar
Hyunsung Lee committed
573

574
575
                // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
                // debug("qkv_raw", qkv);
Hyunsung Lee's avatar
Hyunsung Lee committed
576

577
                debug("rotary_emb", rotary_emb);
Hyunsung Lee's avatar
Hyunsung Lee committed
578

Muyang Li's avatar
Muyang Li committed
579
580
                qkv_proj.forward(
                    norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
581
                debug("qkv", qkv);
Hyunsung Lee's avatar
Hyunsung Lee committed
582

583
584
                // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
                // debug("qkv_context_raw", qkv_context);
Hyunsung Lee's avatar
Hyunsung Lee committed
585

586
                debug("rotary_emb_context", rotary_emb_context);
Hyunsung Lee's avatar
Hyunsung Lee committed
587

Muyang Li's avatar
Muyang Li committed
588
589
590
591
592
593
                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         qkv_context,
                                         pool_qkv_context,
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context);
594
595
                debug("qkv_context", qkv_context);
            }
Hyunsung Lee's avatar
Hyunsung Lee committed
596

fengzch-das's avatar
fengzch-das committed
597
            nvtxRangePop();
598
        }
Hyunsung Lee's avatar
Hyunsung Lee committed
599

600
601
        spdlog::debug("concat={}", concat.shape.str());
        debug("concat", concat);
Hyunsung Lee's avatar
Hyunsung Lee committed
602

603
        assert(concat.shape[2] == num_heads * dim_head * 3);
Hyunsung Lee's avatar
Hyunsung Lee committed
604

fengzch-das's avatar
fengzch-das committed
605
        nvtxRangePushA("Attention");
Hyunsung Lee's avatar
Hyunsung Lee committed
606

607
608
609
610
611
        if (pool.valid()) {
            raw_attn_output = attn.forward(concat, pool, sparsityRatio);
        } else {
            raw_attn_output = attn.forward(concat);
        }
Hyunsung Lee's avatar
Hyunsung Lee committed
612

fengzch-das's avatar
fengzch-das committed
613
        nvtxRangePop();
Hyunsung Lee's avatar
Hyunsung Lee committed
614

615
        spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
Hyunsung Lee's avatar
Hyunsung Lee committed
616

617
        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});
Hyunsung Lee's avatar
Hyunsung Lee committed
618

619
620
621
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
        num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;
Zhekai Zhang's avatar
Zhekai Zhang committed
622

623
        Tensor concat_q, concat_k, concat_v;
Zhekai Zhang's avatar
Zhekai Zhang committed
624

625
        {
fengzch-das's avatar
fengzch-das committed
626
            nvtxRangePushA("qkv_proj");
Hyunsung Lee's avatar
Hyunsung Lee committed
627

Muyang Li's avatar
Muyang Li committed
628
629
630
            concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head},
                                        Tensor::FP16,
                                        norm1_output.x.device());
631
632
            concat_k = Tensor::empty_like(concat_q);
            concat_v = Tensor::empty_like(concat_q);
Hyunsung Lee's avatar
Hyunsung Lee committed
633

634
635
            for (int i = 0; i < batch_size; i++) {
                // img first
Muyang Li's avatar
Muyang Li committed
636
                auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
637
                auto sliceTxt = [&](Tensor x) {
Muyang Li's avatar
Muyang Li committed
638
                    return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
639
                };
Hyunsung Lee's avatar
Hyunsung Lee committed
640

Muyang Li's avatar
Muyang Li committed
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1),
                                 {},
                                 {},
                                 norm_q.weight,
                                 norm_k.weight,
                                 rotary_emb,
                                 sliceImg(concat_q),
                                 sliceImg(concat_k),
                                 sliceImg(concat_v),
                                 num_tokens_img);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         {},
                                         {},
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context,
                                         sliceTxt(concat_q),
                                         sliceTxt(concat_k),
                                         sliceTxt(concat_v),
                                         num_tokens_txt);
662
            }
Zhekai Zhang's avatar
Zhekai Zhang committed
663

664
665
666
            debug("concat_q", concat_q);
            debug("concat_k", concat_k);
            debug("concat_v", concat_v);
Hyunsung Lee's avatar
Hyunsung Lee committed
667

fengzch-das's avatar
fengzch-das committed
668
            nvtxRangePop();
Zhekai Zhang's avatar
Zhekai Zhang committed
669
670
        }

Muyang Li's avatar
Muyang Li committed
671
672
673
        raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head},
                                           norm1_output.x.scalar_type(),
                                           norm1_output.x.device());
Zhekai Zhang's avatar
Zhekai Zhang committed
674

fengzch-das's avatar
fengzch-das committed
675
        nvtxRangePushA("Attention");
Zhekai Zhang's avatar
Zhekai Zhang committed
676

677
        kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));
Zhekai Zhang's avatar
Zhekai Zhang committed
678

fengzch-das's avatar
fengzch-das committed
679
        nvtxRangePop();
Zhekai Zhang's avatar
Zhekai Zhang committed
680

Muyang Li's avatar
Muyang Li committed
681
682
        raw_attn_output =
            raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
683
684
685
    } else {
        assert(false);
    }
Zhekai Zhang's avatar
Zhekai Zhang committed
686
687
688
689

    debug("raw_attn_output", raw_attn_output);

    {
fengzch-das's avatar
fengzch-das committed
690
        nvtxRangePushA("o_proj");
Zhekai Zhang's avatar
Zhekai Zhang committed
691
692
693

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;

694
        // raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]
Zhekai Zhang's avatar
Zhekai Zhang committed
695
696
697

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
Muyang Li's avatar
Muyang Li committed
698
699
            raw_attn_output_split =
                raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
Zhekai Zhang's avatar
Zhekai Zhang committed
700
        } else {
Muyang Li's avatar
Muyang Li committed
701
702
703
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
fengzch-das's avatar
fengzch-das committed
704
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
Muyang Li's avatar
Muyang Li committed
705
706
707
708
709
710
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
fengzch-das's avatar
fengzch-das committed
711
                                        cudaMemcpyDeviceToDevice,
Muyang Li's avatar
Muyang Li committed
712
                                        stream));
Zhekai Zhang's avatar
Zhekai Zhang committed
713
        }
muyangli's avatar
muyangli committed
714

Zhekai Zhang's avatar
Zhekai Zhang committed
715
716
717
        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("img.raw_attn_output_split", raw_attn_output_split);

Muyang Li's avatar
Muyang Li committed
718
719
        Tensor attn_output =
            forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
Zhekai Zhang's avatar
Zhekai Zhang committed
720
721
722
        debug("img.attn_output", attn_output);

#if 1
723
724
        // kernels::mul_add(attn_output, gate_msa, hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
725
726
        hidden_states = std::move(attn_output);

fengzch-das's avatar
fengzch-das committed
727
728
        nvtxRangePop();
        nvtxRangePushA("MLP");
Zhekai Zhang's avatar
Zhekai Zhang committed
729
730
731
732
733
734

        spdlog::debug("attn_output={}", hidden_states.shape.str());

        Tensor norm_hidden_states = norm2.forward(hidden_states);
        debug("scale_mlp", scale_mlp);
        debug("shift_mlp", shift_mlp);
735
736
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
737
738
739
740
741
742
743
744
745
746
747
748

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        Tensor norm_hidden_states = hidden_states;
#endif

        // Tensor ff_output = mlp_fc2.forward(GELU::forward(mlp_fc1.forward(norm_hidden_states)));
        debug("img.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
        debug("img.ff_output", ff_output);

        debug("gate_mlp", gate_mlp);
749
750
        // kernels::mul_add(ff_output, gate_mlp, hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
751
752
        hidden_states = std::move(ff_output);

fengzch-das's avatar
fengzch-das committed
753
        nvtxRangePop();
Zhekai Zhang's avatar
Zhekai Zhang committed
754
755
756
757
758

        spdlog::debug("ff_output={}", hidden_states.shape.str());
    }

    if (context_pre_only) {
Muyang Li's avatar
Muyang Li committed
759
        return {hidden_states, encoder_hidden_states};
Zhekai Zhang's avatar
Zhekai Zhang committed
760
761
762
    }

    {
fengzch-das's avatar
fengzch-das committed
763
        nvtxRangePushA("o_proj_context");
Zhekai Zhang's avatar
Zhekai Zhang committed
764
765
766
767
768

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_context_output;

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
Muyang Li's avatar
Muyang Li committed
769
770
            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt)
                                        .reshape({batch_size, num_tokens_txt, num_heads * dim_head});
Zhekai Zhang's avatar
Zhekai Zhang committed
771
        } else {
Muyang Li's avatar
Muyang Li committed
772
773
774
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
fengzch-das's avatar
fengzch-das committed
775
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
Muyang Li's avatar
Muyang Li committed
776
777
778
779
780
781
782
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head *
                                                                               raw_attn_output_split.scalar_size(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
fengzch-das's avatar
fengzch-das committed
783
                                        cudaMemcpyDeviceToDevice,
Muyang Li's avatar
Muyang Li committed
784
                                        stream));
Zhekai Zhang's avatar
Zhekai Zhang committed
785
        }
muyangli's avatar
muyangli committed
786

Zhekai Zhang's avatar
Zhekai Zhang committed
787
788
789
        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("context.raw_attn_output_split", raw_attn_output_split);

Muyang Li's avatar
Muyang Li committed
790
791
792
        Tensor attn_output =
            forward_fc(out_proj_context,
                       raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
Zhekai Zhang's avatar
Zhekai Zhang committed
793
794
795
        debug("context.attn_output", attn_output);

#if 1
796
797
        // kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
798
799
        encoder_hidden_states = std::move(attn_output);

fengzch-das's avatar
fengzch-das committed
800
801
        nvtxRangePop();
        nvtxRangePushA("MLP");
Zhekai Zhang's avatar
Zhekai Zhang committed
802
803
804
805
806
807

        spdlog::debug("attn_output={}", encoder_hidden_states.shape.str());

        Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
        debug("c_scale_mlp", scale_mlp);
        debug("c_shift_mlp", shift_mlp);
808
809
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
810
811
812
813
814

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        auto norm_hidden_states = encoder_hidden_states;
#endif
muyangli's avatar
muyangli committed
815

Zhekai Zhang's avatar
Zhekai Zhang committed
816
        // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
Muyang Li's avatar
Muyang Li committed
817
818
        // Tensor ff_output =
        // mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
Zhekai Zhang's avatar
Zhekai Zhang committed
819
820
821
822
823
        debug("context.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
        debug("context.ff_output", ff_output);

        debug("c_gate_mlp", gate_mlp);
824
825
        // kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
Zhekai Zhang's avatar
Zhekai Zhang committed
826
827
        encoder_hidden_states = std::move(ff_output);

fengzch-das's avatar
fengzch-das committed
828
        nvtxRangePop();
Zhekai Zhang's avatar
Zhekai Zhang committed
829
830
831
832

        spdlog::debug("ff_output={}", encoder_hidden_states.shape.str());
    }

fengzch-das's avatar
fengzch-das committed
833
    nvtxRangePop();
Zhekai Zhang's avatar
Zhekai Zhang committed
834

Muyang Li's avatar
Muyang Li committed
835
    return {hidden_states, encoder_hidden_states};
Zhekai Zhang's avatar
Zhekai Zhang committed
836
}
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
// Compute only the image-stream query activations of this block: run AdaNorm and
// the fused QKV projections exactly as the attention path does, then pack the Q
// channels of the image tokens into a contiguous tensor and return it.
Tensor JointTransformerBlock::get_q_heads(Tensor hidden_states,
                                          Tensor encoder_hidden_states,
                                          Tensor temb,
                                          Tensor rotary_emb,
                                          Tensor rotary_emb_context,
                                          float sparsityRatio) {
    const int batch_size     = hidden_states.shape[0];
    const int num_tokens_img = hidden_states.shape[1];
    const int num_tokens_txt = encoder_hidden_states.shape[1];

    // AdaNorm modulation for both streams.
    auto norm_img = norm1.forward(hidden_states, temb);
    auto norm_txt = norm1_context.forward(encoder_hidden_states, temb);

    // Packed QKV for the concatenated (image-first) token sequence.
    Tensor concat = Tensor::allocate(
        {batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm_img.x.scalar_type(), norm_img.x.device());

    constexpr int POOL_SIZE = Attention::POOL_SIZE;
    const int poolImg       = num_tokens_img / POOL_SIZE;
    const int poolTokens    = poolImg + num_tokens_txt / POOL_SIZE;

    // Pooled QKV is only needed for the block-sparse path.
    Tensor pool;
    if (sparsityRatio > 0) {
        pool = Tensor::allocate({batch_size, poolTokens, dim * 3}, norm_img.x.scalar_type(), norm_img.x.device());
    }

    // Per-sample fused QKV projection (image tokens first, then text tokens).
    for (int i = 0; i < batch_size; i++) {
        Tensor qkv_img = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
        Tensor qkv_txt = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);

        Tensor pool_img, pool_txt;
        if (pool.valid()) {
            pool_img = pool.slice(0, i, i + 1).slice(1, 0, poolImg);
            pool_txt = pool.slice(0, i, i + 1).slice(1, poolImg, poolTokens);
        }

        qkv_proj.forward(norm_img.x.slice(0, i, i + 1), qkv_img, pool_img, norm_q.weight, norm_k.weight, rotary_emb);
        qkv_proj_context.forward(norm_txt.x.slice(0, i, i + 1),
                                 qkv_txt,
                                 pool_txt,
                                 norm_added_q.weight,
                                 norm_added_k.weight,
                                 rotary_emb_context);
    }

    // Q occupies the first num_heads * dim_head channels; keep only image tokens.
    Tensor q_img = concat.slice(2, 0, num_heads * dim_head).slice(1, 0, num_tokens_img);

    // The slice is a strided view; pack it into freshly-allocated contiguous
    // storage with one pitched 2D async copy per batch element.
    const int numBatch    = q_img.shape.dataExtent[0];
    const int numRows     = q_img.shape.dataExtent[1];
    const int numCols     = q_img.shape.dataExtent[2];
    const size_t elemSize = q_img.scalar_size();
    const size_t srcPitch = q_img.stride(1) * elemSize;
    const size_t rowBytes = numCols * elemSize; // dst pitch == copy width (contiguous rows)

    Tensor packed = Tensor::allocate({numBatch, numRows, numCols}, q_img.scalarType, q_img.device());
    auto stream   = getCurrentCUDAStream();
    for (int b = 0; b < numBatch; ++b) {
        const void *src = (const char *)q_img.data_ptr<char>() + q_img.stride(0) * b * elemSize;
        void *dst       = (char *)packed.data_ptr<char>() + packed.stride(0) * b * elemSize;
        checkCUDA(
            cudaMemcpy2DAsync(dst, rowBytes, src, srcPitch, rowBytes, numRows, cudaMemcpyDeviceToDevice, stream));
    }
    return packed;
}

std::tuple<Tensor, Tensor, Tensor> JointTransformerBlock::forward_ip_adapter_branch(Tensor hidden_states,
                                                                                    Tensor encoder_hidden_states,
                                                                                    Tensor temb,
                                                                                    Tensor rotary_emb,
                                                                                    Tensor rotary_emb_context,
                                                                                    float sparsityRatio) {
    int batch_size = hidden_states.shape[0];
    assert(encoder_hidden_states.shape[0] == batch_size);

fengzch-das's avatar
fengzch-das committed
914
    nvtxRangePushA("JointTransformerBlock");
915

fengzch-das's avatar
fengzch-das committed
916
    nvtxRangePushA("AdaNorm");
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938

    int num_tokens_img = hidden_states.shape[1];
    int num_tokens_txt = encoder_hidden_states.shape[1];

    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);

    Tensor q_heads;

    auto make_contiguous = [&](const Tensor &t) {
        int B            = t.shape.dataExtent[0];
        int R            = t.shape.dataExtent[1];
        int C            = t.shape.dataExtent[2];
        size_t E         = t.scalar_size();
        size_t src_pitch = t.stride(1) * E;

        size_t dst_pitch = C * E;
        size_t width     = C * E;
        size_t height    = R;

        Tensor out = Tensor::allocate({B, R, C}, t.scalarType, t.device());

fengzch-das's avatar
fengzch-das committed
939
        auto stream = getCurrentCUDAStream();
940
941
942
943
        for (int b = 0; b < B; ++b) {
            const void *src = (const char *)t.data_ptr<char>() + t.stride(0) * b * E;
            void *dst       = (char *)out.data_ptr<char>() + out.stride(0) * b * E;
            checkCUDA(
fengzch-das's avatar
fengzch-das committed
944
                cudaMemcpy2DAsync(dst, dst_pitch, src, src_pitch, width, height, cudaMemcpyDeviceToDevice, stream));
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
        }
        return out;
    };

    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}",
                  hidden_states.shape.str(),
                  encoder_hidden_states.shape.str(),
                  temb.shape.str());
    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);

    auto norm1_output         = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

#if 0
    norm1_output.x = hidden_states;
    norm1_context_output.x = encoder_hidden_states;
#endif

    debug("norm_hidden_states", norm1_output.x);
    debug("norm_encoder_hidden_states", norm1_context_output.x);

    constexpr int POOL_SIZE = Attention::POOL_SIZE;

fengzch-das's avatar
fengzch-das committed
968
    nvtxRangePop();
969

fengzch-das's avatar
fengzch-das committed
970
    auto stream = getCurrentCUDAStream();
971
972
973
974
975
976
977
978
979
980
981
982

    int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
    Tensor raw_attn_output;

    if (attnImpl == AttentionImpl::FlashAttention2) {
        num_tokens_img_pad = num_tokens_img;
        num_tokens_txt_pad = num_tokens_txt;

        Tensor concat;
        Tensor pool;

        {
fengzch-das's avatar
fengzch-das committed
983
            nvtxRangePushA("qkv_proj");
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034

            const bool blockSparse = sparsityRatio > 0;

            const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
            concat               = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3},
                                      norm1_output.x.scalar_type(),
                                      norm1_output.x.device());

            pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3},
                                                  norm1_output.x.scalar_type(),
                                                  norm1_output.x.device())
                               : Tensor{};

            for (int i = 0; i < batch_size; i++) {
                // img first
                Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
                Tensor qkv_context =
                    concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);

                Tensor pool_qkv =
                    pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
                Tensor pool_qkv_context = pool.valid()
                                              ? pool.slice(0, i, i + 1)
                                                    .slice(1,
                                                           num_tokens_img / POOL_SIZE,
                                                           num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
                                              : Tensor{};

                // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
                // debug("qkv_raw", qkv);

                debug("rotary_emb", rotary_emb);

                qkv_proj.forward(
                    norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
                debug("qkv", qkv);

                // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
                // debug("qkv_context_raw", qkv_context);

                debug("rotary_emb_context", rotary_emb_context);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         qkv_context,
                                         pool_qkv_context,
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context);
                debug("qkv_context", qkv_context);
            }

fengzch-das's avatar
fengzch-das committed
1035
            nvtxRangePop();
1036
1037
1038
1039
1040
1041
1042
        }

        spdlog::debug("concat={}", concat.shape.str());
        debug("concat", concat);

        assert(concat.shape[2] == num_heads * dim_head * 3);

fengzch-das's avatar
fengzch-das committed
1043
        nvtxRangePushA("Attention");
1044
1045
1046
1047
1048
1049
1050

        if (pool.valid()) {
            raw_attn_output = attn.forward(concat, pool, sparsityRatio);
        } else {
            raw_attn_output = attn.forward(concat);
        }

fengzch-das's avatar
fengzch-das committed
1051
        nvtxRangePop();
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070

        spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());

        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});

        // IP_adapter
        Tensor q_all = concat.slice(2, 0, num_heads * dim_head); // [B, N_total, dim]

        Tensor q_img = q_all.slice(1, 0, num_tokens_img); // [B, N_img, dim]

        q_heads = make_contiguous(q_img);

    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
        num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;

        Tensor concat_q, concat_k, concat_v;

        {
fengzch-das's avatar
fengzch-das committed
1071
            nvtxRangePushA("qkv_proj");
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112

            concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head},
                                        Tensor::FP16,
                                        norm1_output.x.device());
            concat_k = Tensor::empty_like(concat_q);
            concat_v = Tensor::empty_like(concat_q);

            for (int i = 0; i < batch_size; i++) {
                // img first
                auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
                auto sliceTxt = [&](Tensor x) {
                    return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
                };

                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1),
                                 {},
                                 {},
                                 norm_q.weight,
                                 norm_k.weight,
                                 rotary_emb,
                                 sliceImg(concat_q),
                                 sliceImg(concat_k),
                                 sliceImg(concat_v),
                                 num_tokens_img);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         {},
                                         {},
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context,
                                         sliceTxt(concat_q),
                                         sliceTxt(concat_k),
                                         sliceTxt(concat_v),
                                         num_tokens_txt);
            }

            debug("concat_q", concat_q);
            debug("concat_k", concat_k);
            debug("concat_v", concat_v);

            nvtxRangePop();
        }

        raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head},
                                           norm1_output.x.scalar_type(),
                                           norm1_output.x.device());

        nvtxRangePushA("Attention");

        kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));

        nvtxRangePop();

        raw_attn_output =
            raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});

        q_heads = concat_q;
    } else {
        assert(false);
    }

    debug("raw_attn_output", raw_attn_output);

    {
        nvtxRangePushA("o_proj");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;

        // raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split =
                raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("img.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
        debug("img.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
        hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", hidden_states.shape.str());

        Tensor norm_hidden_states = norm2.forward(hidden_states);
        debug("scale_mlp", scale_mlp);
        debug("shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        Tensor norm_hidden_states = hidden_states;
#endif

        // Tensor ff_output = mlp_fc2.forward(GELU::forward(mlp_fc1.forward(norm_hidden_states)));
        debug("img.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
        debug("img.ff_output", ff_output);

        debug("gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
        hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", hidden_states.shape.str());
    }

    if (context_pre_only) {
        return {hidden_states, encoder_hidden_states, q_heads};
    }

    {
        nvtxRangePushA("o_proj_context");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_context_output;

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt)
                                        .reshape({batch_size, num_tokens_txt, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head *
                                                                               raw_attn_output_split.scalar_size(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("context.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj_context,
                       raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
        debug("context.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", encoder_hidden_states.shape.str());

        Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
        debug("c_scale_mlp", scale_mlp);
        debug("c_shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        auto norm_hidden_states = encoder_hidden_states;
#endif

        // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
        // Tensor ff_output =
        // mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
        debug("context.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
        debug("context.ff_output", ff_output);

        debug("c_gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", encoder_hidden_states.shape.str());
    }

    nvtxRangePop();

    return {hidden_states, encoder_hidden_states, q_heads};
}
/// Constructs the Flux transformer: 19 joint (dual-stream) blocks followed by
/// 38 single-stream blocks.
///
/// @param use_fp4  use FP4-quantized weights inside the blocks
/// @param offload  enable lazy loading: block parameters are released after
///                 construction and streamed in on demand (the first joint
///                 block is kept resident)
/// @param dtype    activation scalar type
/// @param device   device the blocks are constructed on
FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
    : dtype(dtype), offload(offload) {
    CUDADeviceContext model_construction_ctx(device.idx);

    for (int i = 0; i < 19; i++) {
        transformer_blocks.push_back(
            std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
        registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
        if (offload && i > 0) { // don't offload first block
            transformer_blocks.back()->setLazyLoad(true);
            transformer_blocks.back()->releaseLazyParams();
        }
    }
    for (int i = 0; i < 38; i++) {
        single_transformer_blocks.push_back(
            std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, device));
        registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
        if (offload) {
            single_transformer_blocks.back()->setLazyLoad(true);
            single_transformer_blocks.back()->releaseLazyParams();
        }
    }
}

/// Runs the full Flux transformer stack: all joint (dual-stream) blocks, then
/// all single-stream blocks operating on the [txt, img] token concatenation.
///
/// @param hidden_states          image-stream tokens, [B, img_tokens, dim]
/// @param encoder_hidden_states  text-stream tokens, [B, txt_tokens, dim]
/// @param temb                   conditioning embedding passed to every block
/// @param rotary_emb_img         rotary embedding for the image stream
/// @param rotary_emb_context     rotary embedding for the text stream
/// @param rotary_emb_single      rotary embedding for the concatenated stream
/// @param controlnet_block_samples / controlnet_single_block_samples
///                               optional ControlNet residuals; pass invalid
///                               tensors to disable
/// @param skip_first_layer       when true, joint block 0 is not executed
/// @return tokens after the last single block, [B, txt_tokens + img_tokens, 3072]
Tensor FluxModel::forward(Tensor hidden_states,
                          Tensor encoder_hidden_states,
                          Tensor temb,
                          Tensor rotary_emb_img,
                          Tensor rotary_emb_context,
                          Tensor rotary_emb_single,
                          Tensor controlnet_block_samples,
                          Tensor controlnet_single_block_samples,
                          bool skip_first_layer) {
    const int batch_size           = hidden_states.shape[0];
    const Tensor::ScalarType dtype = hidden_states.dtype();
    const Device device            = hidden_states.device();

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();

    // Buffer for the [txt, img] concatenation built when entering the
    // single-stream phase; kept alive across compute() invocations.
    Tensor concat;

    auto compute = [&](int layer) {
        if (skip_first_layer && size_t(layer) == 0)
            return;
        if (size_t(layer) < transformer_blocks.size()) {
            // Joint phase: image and text streams are updated together.
            auto &block = transformer_blocks.at(layer);
            std::tie(hidden_states, encoder_hidden_states) =
                block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
            if (controlnet_block_samples.valid()) {
                const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

                // Map this layer onto the (possibly coarser) set of ControlNet residuals.
                int interval_control =
                    ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
                int block_index = layer / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_block_samples;

                hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
            }
            // External residual hook, applied on every other joint block.
            if (residual_callback && layer % 2 == 0) {
                Tensor residual = residual_callback(hidden_states);
                hidden_states   = kernels::add(hidden_states, residual);
            }
        } else {
            if (size_t(layer) == transformer_blocks.size()) {
                // Entering the single-stream phase: txt first, same as diffusers.
                concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
                for (int i = 0; i < batch_size; i++) {
                    concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states.slice(0, i, i + 1));
                    concat.slice(0, i, i + 1)
                        .slice(1, txt_tokens, txt_tokens + img_tokens)
                        .copy_(hidden_states.slice(0, i, i + 1));
                }
                hidden_states         = concat;
                encoder_hidden_states = {};
            }

            auto &block   = single_transformer_blocks.at(layer - transformer_blocks.size());
            hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
            if (controlnet_single_block_samples.valid()) {
                const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

                int interval_control =
                    ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
                int block_index = (layer - transformer_blocks.size()) / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_single_block_samples

                // Single-block residuals only touch the image part of the sequence.
                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice      = kernels::add(slice, controlnet_single_block_samples[block_index]);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
            size_t local_layer_idx = layer - transformer_blocks.size();
            // External residual hook, applied every fourth single block
            // (image part of the sequence only).
            if (residual_callback && local_layer_idx % 4 == 0) {
                Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                Tensor residual       = residual_callback(callback_input);
                auto slice            = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice                 = kernels::add(slice, residual);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
        }
    };
    // load/unload stream block parameters on and off the device for the
    // offload helper; indices past the joint blocks map into the single blocks.
    auto load = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->loadLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->loadLazyParams();
        }
    };
    auto unload = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->releaseLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->releaseLazyParams();
        }
    };

    LayerOffloadHelper helper(this->offload, numLayers, compute, load, unload);
    helper.run();

    return hidden_states;
}

std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
                                                    Tensor hidden_states,
                                                    Tensor encoder_hidden_states,
                                                    Tensor temb,
                                                    Tensor rotary_emb_img,
                                                    Tensor rotary_emb_context,
                                                    Tensor controlnet_block_samples,
                                                    Tensor controlnet_single_block_samples) {

1424
1425
1426
1427
1428
1429
1430
1431
    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->loadLazyParams();
        } else {
            transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
        }
    }

Muyang Li's avatar
Muyang Li committed
1432
    if (layer < transformer_blocks.size()) {
1433
        std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
Muyang Li's avatar
Muyang Li committed
1434
1435
1436
1437
1438
            hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    } else {
        std::tie(hidden_states, encoder_hidden_states) =
            transformer_blocks.at(layer - transformer_blocks.size())
                ->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
1439
    }
Hyunsung Lee's avatar
Hyunsung Lee committed
1440
1441
1442
1443
1444

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
1445
1446
        const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

Hyunsung Lee's avatar
Hyunsung Lee committed
1447
        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
Muyang Li's avatar
Muyang Li committed
1448
        int block_index      = layer / interval_control;
Hyunsung Lee's avatar
Hyunsung Lee committed
1449
1450
1451
1452
1453
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_block_samples;

        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
    } else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
1454
1455
        const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

Muyang Li's avatar
Muyang Li committed
1456
1457
        int interval_control =
            ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
Hyunsung Lee's avatar
Hyunsung Lee committed
1458
1459
1460
1461
1462
        int block_index = (layer - transformer_blocks.size()) / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_single_block_samples

        auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
Muyang Li's avatar
Muyang Li committed
1463
        slice      = kernels::add(slice, controlnet_single_block_samples[block_index]);
Hyunsung Lee's avatar
Hyunsung Lee committed
1464
1465
1466
        hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
    }

1467
1468
1469
1470
1471
1472
1473
1474
    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->releaseLazyParams();
        } else {
            transformer_blocks.at(layer - transformer_blocks.size())->releaseLazyParams();
        }
    }

Muyang Li's avatar
Muyang Li committed
1475
    return {hidden_states, encoder_hidden_states};
Hyunsung Lee's avatar
Hyunsung Lee committed
1476
1477
}

std::tuple<Tensor, Tensor, Tensor> FluxModel::forward_ip_adapter(size_t layer,
                                                                 Tensor hidden_states,         // [B, Nq, dim]
                                                                 Tensor encoder_hidden_states, // [B, Nt, dim]
                                                                 Tensor temb,
                                                                 Tensor rotary_emb_img, // [B, Nq, dim_head]
                                                                 Tensor rotary_emb_context,
                                                                 Tensor controlnet_block_samples,
                                                                 Tensor controlnet_single_block_samples) {
    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->loadLazyParams();
        } else {
            transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
        }
    }

    std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
        hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    Tensor ip_query = transformer_blocks.at(layer)->get_q_heads(
        hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);

    if (controlnet_block_samples.valid()) {
        const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
        int block_index      = layer / interval_control;

        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
    }

    if (offload && layer > 0) {
        transformer_blocks.at(layer)->releaseLazyParams();
    }

    return {hidden_states, encoder_hidden_states, ip_query};
}

// Selects the attention backend used by every block, joint and single-stream.
void FluxModel::setAttentionImpl(AttentionImpl impl) {
    const auto assign = [impl](auto &blocks) {
        for (auto &blk : blocks) {
            blk->attnImpl = impl;
        }
    };
    assign(this->transformer_blocks);
    assign(this->single_transformer_blocks);
}
/// Registers the residual hook invoked by forward() (on every 2nd joint block
/// and every 4th single-stream block). Pass an empty std::function to disable.
void FluxModel::set_residual_callback(std::function<Tensor(const Tensor &)> cb) {
    residual_callback = std::move(cb);
}