#include "FluxModel.h"
#include "kernels/misc_kernels.h"
#include "kernels/gemm_batched.h"
#include "kernels/zgemm/zgemm.h"
#include "activation.h"
#include "Tensor.h"
// #include <nvtx3/nvToolsExt.h>
#include <roctx.h>

#include <pybind11/functional.h>

#include <flash_c_api.h>
#include <iostream>

using spdlog::fmt_lib::format;
using namespace nunchaku;

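// Thin wrapper around the flash-attention C API from <flash_c_api.h>: it queries the required
// workspace size with mha_fwd_workspace, allocates the workspace on q's device, and runs the
// forward pass via mha_fwd, returning an output tensor with the same shape as q.
//
// Example call (shapes only; a sketch assuming q/k/v are BF16/FP16 tensors already on the GPU,
// and head_dim stands for q.shape[3]):
//   Tensor o = call_fa_mha_fwd(q, k, v,
//                              /*p_dropout=*/0.0f,
//                              /*softmax_scale=*/pow(head_dim, -0.5),
//                              /*is_causal=*/false,
//                              /*window_size_left=*/-1, /*window_size_right=*/-1,
//                              /*return_softmax=*/false);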
Tensor call_fa_mha_fwd(Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                       Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
                       Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
                       // c10::optional<at::Tensor> &out_,             // batch_size x seqlen_q x num_heads x head_size
                       // c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
                       const float p_dropout,
                       const float softmax_scale,
                       bool is_causal,
                       int window_size_left,
                       int window_size_right,
                       const bool return_softmax
                       // c10::optional<at::Generator> gen_
) {
    // printf("LOG(INFO) %s: %d %s\n", __FILE__, __LINE__, __func__);
    Tensor o = Tensor::empty_like(q);
    size_t workspace_size = mha_fwd_workspace(
        q.shape[0], q.shape[1], k.shape[1],
        q.shape[2], k.shape[2],
        q.shape[3], k.shape[3],
        false
    );

    const Device device  = q.device();
    Tensor workspace = Tensor::allocate({1, 1, 1, (int)workspace_size}, Tensor::INT8, device);

    mha_fwd(
        q.data_ptr(), k.data_ptr(), v.data_ptr(), o.data_ptr(),
        nullptr,                                    //* alibi
        nullptr,                                    //* rng_state
        workspace.data_ptr(),                       //* workspace
        q.shape[0], q.shape[1], k.shape[1],         //* sizes
        q.shape[2], k.shape[2],
        q.shape[3], k.shape[3],

        q.stride(0), q.stride(1), q.stride(2), q.stride(3),                         //* q strides
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),                         //* k strides
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),                         //* v strides
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),                         //* o strides

        1, 1,                                           //* alibi strides
        p_dropout,                                      //* p_dropout
        softmax_scale,                                  //* softmax_scale
        is_causal,                                      //* is_causal
        window_size_left,
        window_size_right,                              //* window sizes
        0.0f,                                           //* softcap
        return_softmax,                                 //* return_softmax
        0,                                              //* seed
        q.scalar_type() == Tensor::ScalarType::BF16,    //* is_bf16
        false                                           //* is_bhsd
    );

    return o;
}

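// Fused MLP helper: fc1 runs with FuseOptions::GELU_QUANT so its GELU output is emitted directly
// as a quantized activation for fc2, which consumes it through forward_quant without materializing
// the intermediate activation in full precision.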
Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
    std::cout << "Called forward_mlp " << std::endl;
    Tensor ff_output = fc2.forward_quant(std::get<GEMM_W4A4::QuantizedActivation>(
        fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2)));
    return ff_output;
}

// Tensor forward_mlp(GEMM_W8A8 &fc1, GEMM_W8A8 &fc2, Tensor norm_hidden_states) {
//     Tensor ff_output = fc2.forward(fc1.forward(norm_hidden_states), GEMM_W8A8::FuseOptions::GELU);
//     return ff_output;
// }

Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
    return fc.forward(x);
    // return std::get<Tensor>(fc.forward(x));
}

// Tensor forward_fc(GEMM_W8A8 &fc, Tensor x) {
//     return fc.forward(x);
// }

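// AdaLayerNormZeroSingle: SiLU(emb) followed by a linear layer produces 3 * dim modulation
// parameters (shift, scale, gate). forward() returns the normalized input modulated by
// scale/shift together with the gate for the residual connection.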
AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device)
    : dim(dim), linear(dim, 3 * dim, true, dtype, device), norm(dim, 1e-6, false, dtype, device) {
    registerChildren(linear, "linear")(norm, "norm");
}

AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor emb) {
    debug("emb_input", emb);
    emb = linear.forward(Silu::forward(emb));
    debug("emb_linear", emb);
    auto &&[shift_msa, scale_msa, gate_msa] = kernels::split_mod<3>(emb);
    debug("scale_msa", scale_msa);
    debug("shift_msa", shift_msa);

    debug("x", x);
    Tensor norm_x = norm.forward(x);
    debug("norm_x", norm_x);

    // kernels::mul_add(norm_x, scale_msa, shift_msa);
    kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
    return Output{norm_x, gate_msa};
}

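// AdaLayerNormZero: same pattern, but the linear layer produces 6 * dim modulation parameters
// (or 2 * dim when pre_only), additionally returning the MLP shift/scale/gate used later in the
// transformer block.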
AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device)
    : dim(dim), pre_only(pre_only), linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
      norm(dim, 1e-6, false, dtype, device) {
    registerChildren(linear, "linear")(norm, "norm");
}

AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
    debug("x", x);

    debug("emb_input", emb);
    emb = linear.forward(Silu::forward(emb));
    debug("emb_linear", emb);

    if (pre_only) {
        auto &&[shift_msa, scale_msa] = kernels::split_mod<2>(emb);
        debug("shift_msa", shift_msa);

        Tensor norm_x = norm.forward(x);
        debug("norm_x", norm_x);

        // kernels::mul_add(norm_x, scale_msa, shift_msa);
        kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
        debug("norm_x_scaled", norm_x);

        return Output{norm_x};
    } else {
        auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(emb);
        debug("shift_msa", shift_msa);

        Tensor norm_x = norm.forward(x);
        debug("norm_x", norm_x);

        // kernels::mul_add(norm_x, scale_msa, shift_msa);
        kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
        debug("norm_x_scaled", norm_x);

        return Output{norm_x, gate_msa, shift_mlp, scale_mlp, gate_mlp};
    }
}

Attention::Attention(int num_heads, int dim_head, Device device)
    : num_heads(num_heads), dim_head(dim_head), force_fp16(false) {
    headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
    for (int i = 0; i < num_heads; i++) {
        headmask_type.data_ptr<int32_t>()[i] = i + 1;
    }
    headmask_type = headmask_type.copy(device);
}

Tensor Attention::forward(Tensor qkv) {
    assert(qkv.ndims() == 3);

    const Device device  = qkv.device();
    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    Tensor reshaped = qkv.view({batch_size, num_tokens, num_heads * 3, dim_head});
    Tensor q        = reshaped.slice(2, 0, num_heads);
    Tensor k        = reshaped.slice(2, num_heads, num_heads * 2);
    Tensor v        = reshaped.slice(2, num_heads * 2, num_heads * 3);

    Tensor raw_attn_output = call_fa_mha_fwd(q, k, v, 0.0f, pow(q.shape[-1], (-0.5)), false, -1, -1, false);

    assert(raw_attn_output.shape[0] == batch_size);
    assert(raw_attn_output.shape[1] == num_tokens);
    assert(raw_attn_output.shape[2] == num_heads);
    assert(raw_attn_output.shape[3] == dim_head);

    return raw_attn_output.view({batch_size * num_tokens, num_heads, dim_head});
}

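// Block-sparse attention path: pooled QKV (one pooled token per POOL_SIZE tokens) is scored with
// a batched FP16 GEMM and kernels::topk turns the pool scores into a block mask. Note that the
// block-sparse kernel itself (mha_fwd_block) is not available in this build, so this overload
// currently aborts before producing attention output (see below).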
Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
    const bool cast_fp16 = this->force_fp16 && qkv.scalar_type() != Tensor::FP16;

    assert(qkv.ndims() == 3);

    const Device device  = qkv.device();
    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    constexpr int POOL_SIZE = 128;
    const int pool_tokens   = ceilDiv(num_tokens, POOL_SIZE);

    Tensor blockmask;

    if (pool_qkv.valid()) {
        assert(pool_qkv.shape[0] == batch_size);
        assert(pool_qkv.shape[1] == pool_tokens);
        assert(pool_qkv.shape[2] == num_heads * dim_head * 3);
    }

    Tensor pool_score = Tensor::allocate({batch_size, num_heads, pool_tokens, pool_tokens}, Tensor::FP32, device);

    if (pool_qkv.valid() && sparsityRatio > 0) {
        pool_qkv = pool_qkv.view({batch_size, pool_tokens, 3, num_heads, dim_head});
        pool_qkv = pool_qkv.transpose(1, 2).transpose(2, 3); // [batch_size, 3, num_heads, poolTokens, dim_head]
        for (int i = 0; i < batch_size; i++) {
            Tensor pool_q = pool_qkv.slice(0, i, i + 1).slice(1, 0, 1);
            Tensor pool_k = pool_qkv.slice(0, i, i + 1).slice(1, 1, 2);
            Tensor pool_s = pool_score.slice(0, i, i + 1);
            gemm_batched_fp16(pool_q, pool_k, pool_s);
        }
    }

    blockmask = kernels::topk(pool_score, pool_tokens * (1 - sparsityRatio));

    if (cu_seqlens_cpu.valid()) {
        if (cu_seqlens_cpu.shape[0] != batch_size + 1) {
            cu_seqlens_cpu = Tensor{};
        } else {
            for (int i = 0; i <= batch_size; i++) {
                if (cu_seqlens_cpu.data_ptr<int32_t>()[i] != num_tokens * i) {
                    cu_seqlens_cpu = Tensor{};
                    break;
                }
            }
        }
    }
    if (!cu_seqlens_cpu.valid()) {
        cu_seqlens_cpu                        = Tensor::allocate({batch_size + 1}, Tensor::INT32, Device::cpu());
        cu_seqlens_cpu.data_ptr<int32_t>()[0] = 0;
        for (int i = 1; i <= batch_size; i++) {
            cu_seqlens_cpu.data_ptr<int32_t>()[i] = cu_seqlens_cpu.data_ptr<int32_t>()[i - 1] + num_tokens;
        }
    }

    if (cast_fp16) {
        Tensor tmp = Tensor::empty(qkv.shape.dataExtent, Tensor::FP16, qkv.device());
        kernels::cast(qkv, tmp);
        qkv = tmp;
    }

    debug("qkv", qkv);

    Tensor cu_seqlens = cu_seqlens_cpu.copy(device);

    Tensor reshaped = qkv.view({batch_size * num_tokens, num_heads * 3, dim_head});
    Tensor q        = reshaped.slice(1, 0, num_heads);
    Tensor k        = reshaped.slice(1, num_heads, num_heads * 2);
    Tensor v        = reshaped.slice(1, num_heads * 2, num_heads * 3);

    spdlog::debug("q,k,v={}", q.shape.str());

    // Tensor raw_attn_output = mha_fwd_block(q,
    //                                        k,
    //                                        v,
    //                                        cu_seqlens,
    //                                        cu_seqlens,
    //                                        POOL_SIZE,
    //                                        POOL_SIZE,
    //                                        headmask_type,
    //                                        {},
    //                                        blockmask,
    //                                        num_tokens,
    //                                        num_tokens,
    //                                        0.0f,
    //                                        pow(q.shape[-1], (-0.5)),
    //                                        false,
    //                                        false,
    //                                        false,
    //                                        -1,
    //                                        -1)
    //                              .front();
    // NOTE: the block-sparse kernel (mha_fwd_block) is not available in this build, so this
    // path cannot produce a result yet; fail loudly instead of reading an uninitialized tensor.
    std::cout << "mha_fwd_block is not supported in this build" << std::endl;
    assert(false);
    Tensor raw_attn_output = Tensor::empty_like(q);
    debug("raw_attn_output", raw_attn_output);

    if (cast_fp16) {
        Tensor tmp = Tensor::empty(raw_attn_output.shape.dataExtent, Tensor::BF16, raw_attn_output.device());
        kernels::cast(raw_attn_output, tmp);
        raw_attn_output = tmp;
    }

    /**
    Tensor raw_attn_output = mha_varlen_fwd(q, k, v,
        cu_seqlens,
        cu_seqlens,
        concat.shape[1],
        concat.shape[1],
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false,
        true,
        -1, -1,
        false
    ).front();

    Tensor raw_attn_output = mha_fwd(q, k, v,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, -1, -1, false
    ).front();

    Tensor raw_attn_output = mha_varlen_fwd(
        q, k, v,
        cu_seqlens, cu_seqlens,
        num_tokens_img + num_tokens_txt, num_tokens_img + num_tokens_txt,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, false, -1, -1, false
    ).front();
    **/

    assert(raw_attn_output.shape[0] == batch_size * num_tokens);
    assert(raw_attn_output.shape[1] == num_heads);
    assert(raw_attn_output.shape[2] == dim_head);

    return raw_attn_output;
}

void Attention::setForceFP16(Module *module, bool value) {
    spdlog::info("{} force fp16 attention", value ? "Enable" : "Disable");

    module->traverse([&](Module *m) {
        if (Attention *attn = dynamic_cast<Attention *>(m)) {
            attn->force_fp16 = value;
        }
    });
}

FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim,
                                                       int num_attention_heads,
                                                       int attention_head_dim,
                                                       int mlp_ratio,
                                                       bool use_fp4,
                                                       Tensor::ScalarType dtype,
                                                       Device device)
    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
      mlp_hidden_dim(dim * mlp_ratio), norm(dim, dtype, device),
      mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
      mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
      norm_q(dim_head, 1e-6, false, dtype, device), norm_k(dim_head, 1e-6, false, dtype, device),
      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
      out_proj(dim, dim, true, use_fp4, dtype, device) {
    registerChildren(norm, "norm")(mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(qkv_proj, "qkv_proj")(norm_q, "norm_q")(
        norm_k, "norm_k")(attn, "attn")(out_proj, "out_proj");
}

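// Single-stream transformer block: AdaLayerNormZeroSingle yields the modulated input plus a gate,
// qkv_proj fuses the QKV projection with query/key normalization and rotary embeddings, attention
// runs through either FlashAttention-2 or the Nunchaku FP16 kernel, and the attention and MLP
// outputs are summed before the gated residual is added back.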
Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {

    nvtxRangePushA("FluxSingleTransformerBlock");

    const int batch_size = hidden_states.shape[0];
    const int num_tokens = hidden_states.shape[1];

    auto &&[norm_hidden_states, gate] = this->norm.forward(hidden_states, temb);
    debug("norm_hidden_states", norm_hidden_states);
    debug("gate", gate);

    Tensor residual = hidden_states;

    Tensor attn_output;

    debug("rotary_emb", rotary_emb);

    if (attnImpl == AttentionImpl::FlashAttention2) {
        Tensor qkv = Tensor::allocate(
            {batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
        // qkv_proj.forward(norm_hidden_states, qkv, {});
        // debug("qkv_raw", qkv);

        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
                             qkv.slice(0, i, i + 1),
                             {},
                             norm_q.weight,
                             norm_k.weight,
                             rotary_emb);
        }
        debug("qkv", qkv);
        // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);

        // attn_output = attn.forward(qkv, {}, 0);
        attn_output = attn.forward(qkv);
        attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        // assert(batch_size == 1);

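        // The FP16 attention kernel works on token counts padded to a multiple of 256, so Q/K/V
        // are allocated with padded shapes and qkv_proj only writes the first num_tokens positions
        // of each batch entry.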
        const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;

        Tensor q = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor k = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor v = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());

        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
                             {},
                             {},
                             norm_q.weight,
                             norm_k.weight,
                             rotary_emb,
                             q.slice(0, i, i + 1),
                             k.slice(0, i, i + 1),
                             v.slice(0, i, i + 1),
                             num_tokens);
        }

        debug("packed_q", q);
        debug("packed_k", k);
        debug("packed_v", v);

        Tensor o = Tensor::allocate({batch_size, num_tokens_pad, num_heads * dim_head},
                                    norm_hidden_states.scalar_type(),
                                    norm_hidden_states.device());

        kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));

        if (batch_size == 1 || num_tokens_pad == num_tokens) {
            attn_output = o.slice(1, 0, num_tokens);
        } else {
            attn_output = Tensor::allocate({batch_size, num_tokens, num_heads * dim_head}, o.scalar_type(), o.device());
            checkCUDA(cudaMemcpy2DAsync(attn_output.data_ptr(),
                                        attn_output.stride(0) * attn_output.scalar_size(),
                                        o.data_ptr(),
                                        o.stride(0) * o.scalar_size(),
                                        attn_output.stride(0) * attn_output.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        getCurrentCUDAStream()));
        }
    } else {
        assert(false);
    }

    debug("raw_attn_output", attn_output);

    attn_output = forward_fc(out_proj, attn_output);
    debug("attn_output", attn_output);

    Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
    debug("ff_output", ff_output);

    hidden_states = kernels::add(attn_output, ff_output);
    debug("attn_ff_output", hidden_states);

    // kernels::mul_add(hidden_states, gate, residual);
    kernels::mul_add_batch(hidden_states, gate, true, 0.0, residual, true);

    nvtxRangePop();

    return hidden_states;
}

JointTransformerBlock::JointTransformerBlock(int dim,
                                             int num_attention_heads,
                                             int attention_head_dim,
                                             bool context_pre_only,
                                             bool use_fp4,
                                             Tensor::ScalarType dtype,
                                             Device device)
    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
      context_pre_only(context_pre_only), norm1(dim, false, dtype, device),
      norm1_context(dim, context_pre_only, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
      qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device), norm_q(dim_head, 1e-6, false, dtype, device),
      norm_k(dim_head, 1e-6, false, dtype, device), norm_added_q(dim_head, 1e-6, false, dtype, device),
      norm_added_k(dim_head, 1e-6, false, dtype, device),
      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
      out_proj(dim, dim, true, use_fp4, dtype, device), out_proj_context(dim, dim, true, use_fp4, dtype, device),
      norm2(dim, 1e-6, false, dtype, device), norm2_context(dim, 1e-6, false, dtype, device),
      mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device), mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
      mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
      mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device) {
    registerChildren(norm1, "norm1")(norm1_context, "norm1_context")(qkv_proj, "qkv_proj")(qkv_proj_context,
                                                                                           "qkv_proj_context")(
        norm_q, "norm_q")(norm_k, "norm_k")(norm_added_q, "norm_added_q")(norm_added_k, "norm_added_k")(attn, "attn")(
        out_proj, "out_proj")(out_proj_context, "out_proj_context")(norm2, "norm2")(norm2_context, "norm2_context")(
        mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(mlp_context_fc1, "mlp_context_fc1")(mlp_context_fc2, "mlp_context_fc2");
}

// hidden_states: [Batch, Width * Height, dim]
// encoder_hidden_states: [Batch, Token, dim]
std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
                                                          Tensor encoder_hidden_states,
                                                          Tensor temb,
                                                          Tensor rotary_emb,
                                                          Tensor rotary_emb_context,
                                                          float sparsityRatio) {
    int batch_size = hidden_states.shape[0];
    assert(encoder_hidden_states.shape[0] == batch_size);

    nvtxRangePushA("JointTransformerBlock");

    nvtxRangePushA("AdaNorm");

    int num_tokens_img = hidden_states.shape[1];
    int num_tokens_txt = encoder_hidden_states.shape[1];

    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);

    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}",
                  hidden_states.shape.str(),
                  encoder_hidden_states.shape.str(),
                  temb.shape.str());
    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);

    auto norm1_output         = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

#if 0
    norm1_output.x = hidden_states;
    norm1_context_output.x = encoder_hidden_states;
#endif

    debug("norm_hidden_states", norm1_output.x);
    debug("norm_encoder_hidden_states", norm1_context_output.x);

    constexpr int POOL_SIZE = Attention::POOL_SIZE;

    nvtxRangePop();

    auto stream = getCurrentCUDAStream();

    int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
    Tensor raw_attn_output;

    if (attnImpl == AttentionImpl::FlashAttention2) {
        num_tokens_img_pad = num_tokens_img;
        num_tokens_txt_pad = num_tokens_txt;

        Tensor concat;
        Tensor pool;

        {
            nvtxRangePushA("qkv_proj");

            const bool blockSparse = sparsityRatio > 0;

            const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
            concat               = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3},
                                      norm1_output.x.scalar_type(),
                                      norm1_output.x.device());

            pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3},
                                                  norm1_output.x.scalar_type(),
                                                  norm1_output.x.device())
                               : Tensor{};

            for (int i = 0; i < batch_size; i++) {
                // img first
                Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
                Tensor qkv_context =
                    concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);

                Tensor pool_qkv =
                    pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
                Tensor pool_qkv_context = pool.valid()
                                              ? pool.slice(0, i, i + 1)
                                                    .slice(1,
                                                           num_tokens_img / POOL_SIZE,
                                                           num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
                                              : Tensor{};

                // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
                // debug("qkv_raw", qkv);

                debug("rotary_emb", rotary_emb);

                qkv_proj.forward(
                    norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
                debug("qkv", qkv);

                // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
                // debug("qkv_context_raw", qkv_context);

                debug("rotary_emb_context", rotary_emb_context);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         qkv_context,
                                         pool_qkv_context,
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context);
                debug("qkv_context", qkv_context);
            }

            nvtxRangePop();
        }

        spdlog::debug("concat={}", concat.shape.str());
        debug("concat", concat);

        assert(concat.shape[2] == num_heads * dim_head * 3);

        nvtxRangePushA("Attention");

        if (pool.valid()) {
            raw_attn_output = attn.forward(concat, pool, sparsityRatio);
        } else {
            raw_attn_output = attn.forward(concat);
        }

        nvtxRangePop();

        spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());

        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});

    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
        num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;

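        // Both streams are padded to multiples of 256 and packed into shared
        // [batch, num_heads, tokens, dim_head] Q/K/V buffers, image tokens first and text tokens
        // starting at num_tokens_img_pad, before calling kernels::attention_fp16.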
        Tensor concat_q, concat_k, concat_v;

        {
            nvtxRangePushA("qkv_proj");

            concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head},
                                        Tensor::FP16,
                                        norm1_output.x.device());
            concat_k = Tensor::empty_like(concat_q);
            concat_v = Tensor::empty_like(concat_q);

            for (int i = 0; i < batch_size; i++) {
                // img first
                auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
                auto sliceTxt = [&](Tensor x) {
                    return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
                };

                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1),
                                 {},
                                 {},
                                 norm_q.weight,
                                 norm_k.weight,
                                 rotary_emb,
                                 sliceImg(concat_q),
                                 sliceImg(concat_k),
                                 sliceImg(concat_v),
                                 num_tokens_img);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         {},
                                         {},
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context,
                                         sliceTxt(concat_q),
                                         sliceTxt(concat_k),
                                         sliceTxt(concat_v),
                                         num_tokens_txt);
            }

            debug("concat_q", concat_q);
            debug("concat_k", concat_k);
            debug("concat_v", concat_v);

            nvtxRangePop();
        }

        raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head},
                                           norm1_output.x.scalar_type(),
                                           norm1_output.x.device());

        nvtxRangePushA("Attention");

        kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));

        nvtxRangePop();

        raw_attn_output =
            raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
    } else {
        assert(false);
    }

    debug("raw_attn_output", raw_attn_output);

    {
        nvtxRangePushA("o_proj");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;

        // raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split =
                raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("img.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
        debug("img.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
        hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", hidden_states.shape.str());

        Tensor norm_hidden_states = norm2.forward(hidden_states);
        debug("scale_mlp", scale_mlp);
        debug("shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        Tensor norm_hidden_states = hidden_states;
#endif

        // Tensor ff_output = mlp_fc2.forward(GELU::forward(mlp_fc1.forward(norm_hidden_states)));
        debug("img.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
        debug("img.ff_output", ff_output);

        debug("gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
        hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", hidden_states.shape.str());
    }

    if (context_pre_only) {
        return {hidden_states, encoder_hidden_states};
    }

    {
        nvtxRangePushA("o_proj_context");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_context_output;

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt)
                                        .reshape({batch_size, num_tokens_txt, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head *
                                                                               raw_attn_output_split.scalar_size(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("context.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj_context,
                       raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
        debug("context.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", encoder_hidden_states.shape.str());

        Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
        debug("c_scale_mlp", scale_mlp);
        debug("c_shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        auto norm_hidden_states = encoder_hidden_states;
#endif

        // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
        // Tensor ff_output =
        // mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
        debug("context.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
        debug("context.ff_output", ff_output);

        debug("c_gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", encoder_hidden_states.shape.str());
    }

    nvtxRangePop();

    return {hidden_states, encoder_hidden_states};
}
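
// Runs only the AdaNorm + QKV projection part of the block and returns a contiguous copy of the
// image-stream query heads ([batch, num_tokens_img, num_heads * dim_head]); the strided slice of
// the packed QKV buffer is densified with a per-batch cudaMemcpy2DAsync.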
Tensor JointTransformerBlock::get_q_heads(Tensor hidden_states,
                                          Tensor encoder_hidden_states,
                                          Tensor temb,
                                          Tensor rotary_emb,
                                          Tensor rotary_emb_context,
                                          float sparsityRatio) {
    int batch_size     = hidden_states.shape[0];
    int num_tokens_img = hidden_states.shape[1];
    int num_tokens_txt = encoder_hidden_states.shape[1];

    // Apply AdaNorm.
    auto norm1_output         = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

    Tensor concat = Tensor::allocate(
        {batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());

    const bool blockSparse  = sparsityRatio > 0;
    constexpr int POOL_SIZE = Attention::POOL_SIZE;
    const int poolTokens    = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
    Tensor pool =
        blockSparse
            ? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
            : Tensor{};

    // QKV Projection.
    for (int i = 0; i < batch_size; i++) {
        Tensor qkv         = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
        Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
        Tensor pool_qkv    = pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
        Tensor pool_qkv_context =
            pool.valid() ? pool.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, poolTokens) : Tensor{};

        qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
        qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                 qkv_context,
                                 pool_qkv_context,
                                 norm_added_q.weight,
                                 norm_added_k.weight,
                                 rotary_emb_context);
    }

    // Extract and return q_heads.
    Tensor q_all = concat.slice(2, 0, num_heads * dim_head);
    Tensor q_img = q_all.slice(1, 0, num_tokens_img);

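    // make_contiguous copies a strided [B, R, C] view into a freshly allocated dense tensor using
    // one pitched cudaMemcpy2DAsync per batch entry (src pitch = stride(1), dst pitch = row size).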
    auto make_contiguous = [&](const Tensor &t) {
        int B            = t.shape.dataExtent[0];
        int R            = t.shape.dataExtent[1];
        int C            = t.shape.dataExtent[2];
        size_t E         = t.scalar_size();
        size_t src_pitch = t.stride(1) * E;
        size_t dst_pitch = C * E;
        size_t width     = C * E;
        size_t height    = R;
        Tensor out       = Tensor::allocate({B, R, C}, t.scalarType, t.device());
        auto stream      = getCurrentCUDAStream();
        for (int b = 0; b < B; ++b) {
            const void *src = (const char *)t.data_ptr<char>() + t.stride(0) * b * E;
            void *dst       = (char *)out.data_ptr<char>() + out.stride(0) * b * E;
            checkCUDA(
                cudaMemcpy2DAsync(dst, dst_pitch, src, src_pitch, width, height, cudaMemcpyDeviceToDevice, stream));
        }
        return out;
    };
    return make_contiguous(q_img);
}

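// Variant of forward() used by the IP-Adapter integration: it performs the same joint-attention
// computation but additionally captures a contiguous copy of the image-stream query heads
// (q_heads) for the IP-Adapter path on the caller side.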
std::tuple<Tensor, Tensor, Tensor> JointTransformerBlock::forward_ip_adapter_branch(Tensor hidden_states,
                                                                                    Tensor encoder_hidden_states,
                                                                                    Tensor temb,
                                                                                    Tensor rotary_emb,
                                                                                    Tensor rotary_emb_context,
                                                                                    float sparsityRatio) {
    int batch_size = hidden_states.shape[0];
    assert(encoder_hidden_states.shape[0] == batch_size);

    nvtxRangePushA("JointTransformerBlock");

    nvtxRangePushA("AdaNorm");

    int num_tokens_img = hidden_states.shape[1];
    int num_tokens_txt = encoder_hidden_states.shape[1];

    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);

    Tensor q_heads;

    auto make_contiguous = [&](const Tensor &t) {
        int B            = t.shape.dataExtent[0];
        int R            = t.shape.dataExtent[1];
        int C            = t.shape.dataExtent[2];
        size_t E         = t.scalar_size();
        size_t src_pitch = t.stride(1) * E;

        size_t dst_pitch = C * E;
        size_t width     = C * E;
        size_t height    = R;

        Tensor out = Tensor::allocate({B, R, C}, t.scalarType, t.device());

        auto stream = getCurrentCUDAStream();
        for (int b = 0; b < B; ++b) {
            const void *src = (const char *)t.data_ptr<char>() + t.stride(0) * b * E;
            void *dst       = (char *)out.data_ptr<char>() + out.stride(0) * b * E;
            checkCUDA(
                cudaMemcpy2DAsync(dst, dst_pitch, src, src_pitch, width, height, cudaMemcpyDeviceToDevice, stream));
        }
        return out;
    };

    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}",
                  hidden_states.shape.str(),
                  encoder_hidden_states.shape.str(),
                  temb.shape.str());
    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);

    auto norm1_output         = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

#if 0
    norm1_output.x = hidden_states;
    norm1_context_output.x = encoder_hidden_states;
#endif

    debug("norm_hidden_states", norm1_output.x);
    debug("norm_encoder_hidden_states", norm1_context_output.x);

    constexpr int POOL_SIZE = Attention::POOL_SIZE;

    nvtxRangePop();

    auto stream = getCurrentCUDAStream();

    int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
    Tensor raw_attn_output;

    if (attnImpl == AttentionImpl::FlashAttention2) {
        num_tokens_img_pad = num_tokens_img;
        num_tokens_txt_pad = num_tokens_txt;

        Tensor concat;
        Tensor pool;

        {
            nvtxRangePushA("qkv_proj");

            const bool blockSparse = sparsityRatio > 0;

            const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
            concat               = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3},
                                      norm1_output.x.scalar_type(),
                                      norm1_output.x.device());

            pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3},
                                                  norm1_output.x.scalar_type(),
                                                  norm1_output.x.device())
                               : Tensor{};

            for (int i = 0; i < batch_size; i++) {
                // img first
                Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
                Tensor qkv_context =
                    concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);

                Tensor pool_qkv =
                    pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
                Tensor pool_qkv_context = pool.valid()
                                              ? pool.slice(0, i, i + 1)
                                                    .slice(1,
                                                           num_tokens_img / POOL_SIZE,
                                                           num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
                                              : Tensor{};

                // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
                // debug("qkv_raw", qkv);

                debug("rotary_emb", rotary_emb);

                qkv_proj.forward(
                    norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
                debug("qkv", qkv);

                // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
                // debug("qkv_context_raw", qkv_context);

                debug("rotary_emb_context", rotary_emb_context);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         qkv_context,
                                         pool_qkv_context,
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context);
                debug("qkv_context", qkv_context);
            }

            nvtxRangePop();
        }

        spdlog::debug("concat={}", concat.shape.str());
        debug("concat", concat);

        assert(concat.shape[2] == num_heads * dim_head * 3);

        nvtxRangePushA("Attention");

        if (pool.valid()) {
            raw_attn_output = attn.forward(concat, pool, sparsityRatio);
        } else {
            raw_attn_output = attn.forward(concat);
        }

        nvtxRangePop();

        spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());

        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});

        // IP_adapter
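        // The packed buffer is laid out as [Q | K | V] along the channel axis, so the first
        // num_heads * dim_head channels are the queries (presumably already q-normed and
        // rotated inside qkv_proj); keeping only the image tokens yields the IP-adapter query.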
        Tensor q_all = concat.slice(2, 0, num_heads * dim_head); // [B, N_total, dim]

        Tensor q_img = q_all.slice(1, 0, num_tokens_img); // [B, N_img, dim]

        q_heads = make_contiguous(q_img);

    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
        num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;
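        // The FP16 attention path pads each modality to a multiple of 256 tokens, presumably to
        // match the kernel's tile size; the padding is sliced away again before the output
        // projections below.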

        Tensor concat_q, concat_k, concat_v;

        {
            nvtxRangePushA("qkv_proj");

            concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head},
                                        Tensor::FP16,
                                        norm1_output.x.device());
            concat_k = Tensor::empty_like(concat_q);
            concat_v = Tensor::empty_like(concat_q);

            for (int i = 0; i < batch_size; i++) {
                // img first
                auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
                auto sliceTxt = [&](Tensor x) {
                    return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
                };

                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1),
                                 {},
                                 {},
                                 norm_q.weight,
                                 norm_k.weight,
                                 rotary_emb,
                                 sliceImg(concat_q),
                                 sliceImg(concat_k),
                                 sliceImg(concat_v),
                                 num_tokens_img);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         {},
                                         {},
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context,
                                         sliceTxt(concat_q),
                                         sliceTxt(concat_k),
                                         sliceTxt(concat_v),
                                         num_tokens_txt);
            }

            debug("concat_q", concat_q);
            debug("concat_k", concat_k);
            debug("concat_v", concat_v);

            nvtxRangePop();
        }

        raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head},
                                           norm1_output.x.scalar_type(),
                                           norm1_output.x.device());
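        // Fused FP16 attention over the padded Q/K/V with the usual 1/sqrt(dim_head) softmax
        // scale, writing [B, N_img_pad + N_txt_pad, num_heads * dim_head] into raw_attn_output.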

        nvtxRangePushA("Attention");

        kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));

        nvtxRangePop();

        raw_attn_output =
            raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});

        q_heads = concat_q;
    } else {
        assert(false);
    }

    debug("raw_attn_output", raw_attn_output);

    {
        nvtxRangePushA("o_proj");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;

        // raw_attn_output: [batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head]

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split =
                raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
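            // With batch_size > 1 the image tokens of each batch are separated by text (and
            // possibly padding) rows, so a strided 2D copy gathers the first num_tokens_img
            // rows of every batch into a dense buffer.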
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("img.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
        debug("img.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, hidden_states);
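        // Batched form of the gated residual above; presumably computes
        // attn_output = gate_msa * attn_output + hidden_states with gate_msa broadcast per batch.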
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
        hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", hidden_states.shape.str());

        Tensor norm_hidden_states = norm2.forward(hidden_states);
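        // AdaLN-style modulation of the MLP input: scale_mlp and shift_mlp come from the
        // temb-conditioned norm1 output and are applied per batch element.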
        debug("scale_mlp", scale_mlp);
        debug("shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        Tensor norm_hidden_states = hidden_states;
#endif

        // Tensor ff_output = mlp_fc2.forward(GELU::forward(mlp_fc1.forward(norm_hidden_states)));
        debug("img.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
        debug("img.ff_output", ff_output);

        debug("gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
        hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", hidden_states.shape.str());
    }

    if (context_pre_only) {
        return {hidden_states, encoder_hidden_states, q_heads};
    }

    {
        nvtxRangePushA("o_proj_context");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_context_output;

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt)
                                        .reshape({batch_size, num_tokens_txt, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head *
                                                                               raw_attn_output_split.scalar_size(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("context.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj_context,
                       raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
        debug("context.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", encoder_hidden_states.shape.str());

        Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
        debug("c_scale_mlp", scale_mlp);
        debug("c_shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        auto norm_hidden_states = encoder_hidden_states;
#endif

        // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
        // Tensor ff_output =
        // mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
        debug("context.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
        debug("context.ff_output", ff_output);

        debug("c_gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", encoder_hidden_states.shape.str());
    }

    nvtxRangePop();

    return {hidden_states, encoder_hidden_states, q_heads};
}

FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
    : dtype(dtype), offload(offload) {
    CUDADeviceContext model_construction_ctx(device.idx);
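    // FLUX.1 layout: 19 joint (dual-stream image/text) blocks followed by 38 single-stream
    // blocks. With offloading, every block except the first joint block is set to lazy-load
    // and its weights are released right after construction.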

    for (int i = 0; i < 19; i++) {
        transformer_blocks.push_back(
            std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
        registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
        if (offload && i > 0) { // don't offload first block
            transformer_blocks.back()->setLazyLoad(true);
            transformer_blocks.back()->releaseLazyParams();
        }
    }
    for (int i = 0; i < 38; i++) {
        single_transformer_blocks.push_back(
            std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, device));
        registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
        if (offload) {
            single_transformer_blocks.back()->setLazyLoad(true);
            single_transformer_blocks.back()->releaseLazyParams();
        }
    }
}

Tensor FluxModel::forward(Tensor hidden_states,
                          Tensor encoder_hidden_states,
                          Tensor temb,
                          Tensor rotary_emb_img,
                          Tensor rotary_emb_context,
                          Tensor rotary_emb_single,
                          Tensor controlnet_block_samples,
                          Tensor controlnet_single_block_samples,
                          bool skip_first_layer) {
    const int batch_size           = hidden_states.shape[0];
    const Tensor::ScalarType dtype = hidden_states.dtype();
    const Device device            = hidden_states.device();

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();

    Tensor concat;

    auto compute = [&](int layer) {
        if (skip_first_layer && size_t(layer) == 0)
            return;
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            std::tie(hidden_states, encoder_hidden_states) =
                block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
            if (controlnet_block_samples.valid()) {
                const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

                int interval_control =
                    ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
                int block_index = layer / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_block_samples;

                hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
            }
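            // Optional external residual hook: applied on every second joint block here, and on
            // every fourth single block (image tokens only) further below.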
            if (residual_callback && layer % 2 == 0) {
                Tensor residual = residual_callback(hidden_states);
                hidden_states   = kernels::add(hidden_states, residual);
            }
        } else {
            if (size_t(layer) == transformer_blocks.size()) {
                // txt first, same as diffusers
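                // Entering the single-stream blocks: text and image tokens are merged into one
                // sequence (text first, hidden size 3072) and the separate encoder stream is dropped.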
                concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
                for (int i = 0; i < batch_size; i++) {
                    concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states.slice(0, i, i + 1));
                    concat.slice(0, i, i + 1)
                        .slice(1, txt_tokens, txt_tokens + img_tokens)
                        .copy_(hidden_states.slice(0, i, i + 1));
                }
                hidden_states         = concat;
                encoder_hidden_states = {};
            }

            auto &block   = single_transformer_blocks.at(layer - transformer_blocks.size());
            hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
            if (controlnet_single_block_samples.valid()) {
                const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

                int interval_control =
                    ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
                int block_index = (layer - transformer_blocks.size()) / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_single_block_samples

                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice      = kernels::add(slice, controlnet_single_block_samples[block_index]);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
            size_t local_layer_idx = layer - transformer_blocks.size();
            if (residual_callback && local_layer_idx % 4 == 0) {
                Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                Tensor residual       = residual_callback(callback_input);
                auto slice            = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice                 = kernels::add(slice, residual);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
        }
    };
    auto load = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->loadLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->loadLazyParams();
        }
    };
    auto unload = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->releaseLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->releaseLazyParams();
        }
    };
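    // LayerOffloadHelper drives the per-layer schedule: it runs compute for each of the
    // numLayers blocks and, when offloading is enabled, presumably overlaps loading/unloading
    // of neighbouring layers' weights with the current layer's compute.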

    LayerOffloadHelper helper(this->offload, numLayers, compute, load, unload);
    helper.run();

    return hidden_states;
}

std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
                                                    Tensor hidden_states,
                                                    Tensor encoder_hidden_states,
                                                    Tensor temb,
                                                    Tensor rotary_emb_img,
                                                    Tensor rotary_emb_context,
                                                    Tensor controlnet_block_samples,
                                                    Tensor controlnet_single_block_samples) {

    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->loadLazyParams();
        } else {
            single_transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
        }
    }

    if (layer < transformer_blocks.size()) {
        std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
            hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    } else {
        std::tie(hidden_states, encoder_hidden_states) =
            transformer_blocks.at(layer - transformer_blocks.size())
                ->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    }

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
        const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
        int block_index      = layer / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_block_samples;

        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
    } else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
        const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

1456
        int interval_control =
            ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
        int block_index = (layer - transformer_blocks.size()) / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_single_block_samples

        auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
        slice      = kernels::add(slice, controlnet_single_block_samples[block_index]);
        hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
    }

    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->releaseLazyParams();
        } else {
            single_transformer_blocks.at(layer - transformer_blocks.size())->releaseLazyParams();
        }
    }

    return {hidden_states, encoder_hidden_states};
}

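// Runs a single joint transformer block and additionally returns that layer's IP-adapter query
// heads (via get_q_heads); as written it only handles layer < transformer_blocks.size().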
std::tuple<Tensor, Tensor, Tensor> FluxModel::forward_ip_adapter(size_t layer,
                                                                 Tensor hidden_states,         // [B, Nq, dim]
                                                                 Tensor encoder_hidden_states, // [B, Nt, dim]
                                                                 Tensor temb,
                                                                 Tensor rotary_emb_img, // [B, Nq, dim_head]
                                                                 Tensor rotary_emb_context,
                                                                 Tensor controlnet_block_samples,
                                                                 Tensor controlnet_single_block_samples) {
    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->loadLazyParams();
        } else {
            single_transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
        }
    }

    std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
        hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    Tensor ip_query = transformer_blocks.at(layer)->get_q_heads(
        hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);

    if (controlnet_block_samples.valid()) {
        const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
        int block_index      = layer / interval_control;

        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
    }

    if (offload && layer > 0) {
        transformer_blocks.at(layer)->releaseLazyParams();
    }

    return {hidden_states, encoder_hidden_states, ip_query};
}

void FluxModel::setAttentionImpl(AttentionImpl impl) {
    for (auto &&block : this->transformer_blocks) {
        block->attnImpl = impl;
    }
    for (auto &&block : this->single_transformer_blocks) {
        block->attnImpl = impl;
    }
}
void FluxModel::set_residual_callback(std::function<Tensor(const Tensor &)> cb) {
    residual_callback = std::move(cb);
}