#include "FluxModel.h"
#include "kernels/misc_kernels.h"
#include "kernels/gemm_batched.h"
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "activation.h"
#include <nvtx3/nvToolsExt.h>

#include <pybind11/functional.h>

#include <iostream>

using spdlog::fmt_lib::format;
using namespace nunchaku;

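// Fused feed-forward helper for the W4A4 path: fc1 runs with GELU and re-quantization
// fused into its epilogue, so fc2 can consume the quantized activation directly via
// forward_quant.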
Tensor forward_mlp(GEMM_W4A4 &fc1, GEMM_W4A4 &fc2, Tensor norm_hidden_states) {
    Tensor ff_output = fc2.forward_quant(std::get<GEMM_W4A4::QuantizedActivation>(
        fc1.forward(norm_hidden_states, GEMM_W4A4::FuseOptions::GELU_QUANT, &fc2)));
    return ff_output;
}

// Tensor forward_mlp(GEMM_W8A8 &fc1, GEMM_W8A8 &fc2, Tensor norm_hidden_states) {
//     Tensor ff_output = fc2.forward(fc1.forward(norm_hidden_states), GEMM_W8A8::FuseOptions::GELU);
//     return ff_output;
// }

Tensor forward_fc(GEMM_W4A4 &fc, Tensor x) {
    return fc.forward(x);
    // return std::get<Tensor>(fc.forward(x));
}

// Tensor forward_fc(GEMM_W8A8 &fc, Tensor x) {
//     return fc.forward(x);
// }

AdaLayerNormZeroSingle::AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device)
    : dim(dim), linear(dim, 3 * dim, true, dtype, device), norm(dim, 1e-6, false, dtype, device) {
    registerChildren(linear, "linear")(norm, "norm");
}

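// adaLN-Zero, single-stream variant: SiLU + linear projects the conditioning embedding
// into shift/scale/gate chunks, the input is layer-normalized, and mul_add_batch applies
// the scale and shift in place; the gate is returned for the caller to apply later.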
AdaLayerNormZeroSingle::Output AdaLayerNormZeroSingle::forward(Tensor x, Tensor emb) {
    debug("emb_input", emb);
    emb = linear.forward(Silu::forward(emb));
    debug("emb_linear", emb);
    auto &&[shift_msa, scale_msa, gate_msa] = kernels::split_mod<3>(emb);
    debug("scale_msa", scale_msa);
    debug("shift_msa", shift_msa);

    debug("x", x);
    Tensor norm_x = norm.forward(x);
    debug("norm_x", norm_x);
    // kernels::mul_add(norm_x, scale_msa, shift_msa);
    kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
    return Output{norm_x, gate_msa};
}

AdaLayerNormZero::AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device)
    : dim(dim), pre_only(pre_only), linear(dim, pre_only ? 2 * dim : 6 * dim, true, dtype, device),
      norm(dim, 1e-6, false, dtype, device) {
    registerChildren(linear, "linear")(norm, "norm");
}

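// adaLN-Zero: with pre_only the embedding yields only shift/scale for the attention
// branch; otherwise it yields six chunks (shift/scale/gate for both the attention and
// MLP branches). The layer-normalized input is modulated in place via mul_add_batch.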
AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
    debug("x", x);

    debug("emb_input", emb);
    emb = linear.forward(Silu::forward(emb));
    debug("emb_linear", emb);

    if (pre_only) {
        auto &&[shift_msa, scale_msa] = kernels::split_mod<2>(emb);
        debug("shift_msa", shift_msa);

        Tensor norm_x = norm.forward(x);
        debug("norm_x", norm_x);

        // kernels::mul_add(norm_x, scale_msa, shift_msa);
        kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
        debug("norm_x_scaled", norm_x);

        return Output{norm_x};
    } else {
        auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(emb);
        debug("shift_msa", shift_msa);

        Tensor norm_x = norm.forward(x);
        debug("norm_x", norm_x);

        // kernels::mul_add(norm_x, scale_msa, shift_msa);
        kernels::mul_add_batch(norm_x, scale_msa, true, 0.0, shift_msa, true);
        debug("norm_x_scaled", norm_x);

        return Output{norm_x, gate_msa, shift_mlp, scale_mlp, gate_mlp};
    }
}

Attention::Attention(int num_heads, int dim_head, Device device)
    : num_heads(num_heads), dim_head(dim_head), force_fp16(false) {
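    // Give each head its own mask type (1..num_heads); the table is built on the CPU
    // and then copied to the target device for the block-sparse attention kernel.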
    headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
    for (int i = 0; i < num_heads; i++) {
        headmask_type.data_ptr<int32_t>()[i] = i + 1;
    }
    headmask_type = headmask_type.copy(device);
}

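// Dense attention path: qkv is packed as [batch, tokens, num_heads * 3 * dim_head],
// q/k/v are sliced out along the head dimension, and FlashAttention (mha_fwd) is run
// with the usual 1/sqrt(dim_head) softmax scaling.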
Tensor Attention::forward(Tensor qkv) {
    assert(qkv.ndims() == 3);

    const Device device  = qkv.device();
    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    Tensor reshaped = qkv.view({batch_size, num_tokens, num_heads * 3, dim_head});
    Tensor q        = reshaped.slice(2, 0, num_heads);
    Tensor k        = reshaped.slice(2, num_heads, num_heads * 2);
    Tensor v        = reshaped.slice(2, num_heads * 2, num_heads * 3);

    Tensor raw_attn_output = mha_fwd(q, k, v, 0.0f, pow(q.shape[-1], (-0.5)), false, -1, -1, false).front();

    assert(raw_attn_output.shape[0] == batch_size);
    assert(raw_attn_output.shape[1] == num_tokens);
    assert(raw_attn_output.shape[2] == num_heads);
    assert(raw_attn_output.shape[3] == dim_head);

    return raw_attn_output.view({batch_size * num_tokens, num_heads, dim_head});
}

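// Block-sparse attention path: when a pooled qkv is provided, per-head block scores are
// computed with a batched GEMM over 128-token pools, kernels::topk keeps the
// highest-scoring blocks, and mha_fwd_block restricts attention to that block mask.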
Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
    const bool cast_fp16 = this->force_fp16 && qkv.scalar_type() != Tensor::FP16;

    assert(qkv.ndims() == 3);

    const Device device  = qkv.device();
    const int batch_size = qkv.shape[0];
    const int num_tokens = qkv.shape[1];
    assert(qkv.shape[2] == num_heads * dim_head * 3);

    constexpr int POOL_SIZE = 128;
    const int pool_tokens   = ceilDiv(num_tokens, POOL_SIZE);

    Tensor blockmask;

    if (pool_qkv.valid()) {
        assert(pool_qkv.shape[0] == batch_size);
        assert(pool_qkv.shape[1] == pool_tokens);
        assert(pool_qkv.shape[2] == num_heads * dim_head * 3);
    }

    Tensor pool_score = Tensor::allocate({batch_size, num_heads, pool_tokens, pool_tokens}, Tensor::FP32, device);

    if (pool_qkv.valid() && sparsityRatio > 0) {
        pool_qkv = pool_qkv.view({batch_size, pool_tokens, 3, num_heads, dim_head});
        pool_qkv = pool_qkv.transpose(1, 2).transpose(2, 3); // [batch_size, 3, num_heads, poolTokens, dim_head]
        for (int i = 0; i < batch_size; i++) {
            Tensor pool_q = pool_qkv.slice(0, i, i + 1).slice(1, 0, 1);
            Tensor pool_k = pool_qkv.slice(0, i, i + 1).slice(1, 1, 2);
            Tensor pool_s = pool_score.slice(0, i, i + 1);
            gemm_batched_fp16(pool_q, pool_k, pool_s);
        }
    }

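    // Keep roughly a (1 - sparsityRatio) fraction of the pooled blocks; the rest are masked out.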
    blockmask = kernels::topk(pool_score, pool_tokens * (1 - sparsityRatio));

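    // Cache the cumulative sequence lengths on the CPU and rebuild them only when the
    // batch size or per-sequence token count changes.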
    if (cu_seqlens_cpu.valid()) {
        if (cu_seqlens_cpu.shape[0] != batch_size + 1) {
            cu_seqlens_cpu = Tensor{};
        } else {
            for (int i = 0; i <= batch_size; i++) {
                if (cu_seqlens_cpu.data_ptr<int32_t>()[i] != num_tokens * i) {
                    cu_seqlens_cpu = Tensor{};
                    break;
                }
            }
        }
    }
    if (!cu_seqlens_cpu.valid()) {
        cu_seqlens_cpu                        = Tensor::allocate({batch_size + 1}, Tensor::INT32, Device::cpu());
        cu_seqlens_cpu.data_ptr<int32_t>()[0] = 0;
        for (int i = 1; i <= batch_size; i++) {
            cu_seqlens_cpu.data_ptr<int32_t>()[i] = cu_seqlens_cpu.data_ptr<int32_t>()[i - 1] + num_tokens;
        }
    }

    if (cast_fp16) {
        Tensor tmp = Tensor::empty(qkv.shape.dataExtent, Tensor::FP16, qkv.device());
        kernels::cast(qkv, tmp);
        qkv = tmp;
    }

    debug("qkv", qkv);

    Tensor cu_seqlens = cu_seqlens_cpu.copy(device);

    Tensor reshaped = qkv.view({batch_size * num_tokens, num_heads * 3, dim_head});
    Tensor q        = reshaped.slice(1, 0, num_heads);
    Tensor k        = reshaped.slice(1, num_heads, num_heads * 2);
    Tensor v        = reshaped.slice(1, num_heads * 2, num_heads * 3);

    spdlog::debug("q,k,v={}", q.shape.str());

    Tensor raw_attn_output = mha_fwd_block(q,
                                           k,
                                           v,
                                           cu_seqlens,
                                           cu_seqlens,
                                           POOL_SIZE,
                                           POOL_SIZE,
                                           headmask_type,
                                           {},
                                           blockmask,
                                           num_tokens,
                                           num_tokens,
                                           0.0f,
                                           pow(q.shape[-1], (-0.5)),
                                           false,
                                           false,
                                           false,
                                           -1,
                                           -1)
                                 .front();

    debug("raw_attn_output", raw_attn_output);

    if (cast_fp16) {
        Tensor tmp = Tensor::empty(raw_attn_output.shape.dataExtent, Tensor::BF16, raw_attn_output.device());
        kernels::cast(raw_attn_output, tmp);
        raw_attn_output = tmp;
    }

    /**
    Tensor raw_attn_output = mha_varlen_fwd(q, k, v,
        cu_seqlens,
        cu_seqlens,
        concat.shape[1],
        concat.shape[1],
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false,
        true,
        -1, -1,
        false
    ).front();

    Tensor raw_attn_output = mha_fwd(q, k, v,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, -1, -1, false
    ).front();

    Tensor raw_attn_output = mha_varlen_fwd(
        q, k, v,
        cu_seqlens, cu_seqlens,
        num_tokens_img + num_tokens_txt, num_tokens_img + num_tokens_txt,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, false, -1, -1, false
    ).front();
    **/

    assert(raw_attn_output.shape[0] == batch_size * num_tokens);
    assert(raw_attn_output.shape[1] == num_heads);
    assert(raw_attn_output.shape[2] == dim_head);

    return raw_attn_output;
}

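// Toggle FP16 attention on every Attention module reachable from `module`
// by walking the module tree.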
void Attention::setForceFP16(Module *module, bool value) {
    spdlog::info("{} force fp16 attention", value ? "Enable" : "Disable");

    module->traverse([&](Module *m) {
        if (Attention *attn = dynamic_cast<Attention *>(m)) {
            attn->force_fp16 = value;
        }
    });
}

FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim,
                                                       int num_attention_heads,
                                                       int attention_head_dim,
                                                       int mlp_ratio,
                                                       bool use_fp4,
                                                       Tensor::ScalarType dtype,
                                                       Device device)
    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
      mlp_hidden_dim(dim * mlp_ratio), norm(dim, dtype, device),
      mlp_fc1(dim, mlp_hidden_dim, true, use_fp4, dtype, device),
      mlp_fc2(mlp_hidden_dim, dim, true, use_fp4, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
      norm_q(dim_head, 1e-6, false, dtype, device), norm_k(dim_head, 1e-6, false, dtype, device),
      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
      out_proj(dim, dim, true, use_fp4, dtype, device) {
    registerChildren(norm, "norm")(mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(qkv_proj, "qkv_proj")(norm_q, "norm_q")(
        norm_k, "norm_k")(attn, "attn")(out_proj, "out_proj");
}

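// Single-stream FLUX block: adaLN-Zero modulation, fused QKV projection (Q/K
// normalization and rotary embedding applied inside the projection), attention via
// FlashAttention-2 or the Nunchaku FP16 kernel, then the attention output projection
// and the MLP are summed and gated back onto the residual.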
Tensor FluxSingleTransformerBlock::forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb) {

    nvtxRangePushA("FluxSingleTransformerBlock");

    const int batch_size = hidden_states.shape[0];
    const int num_tokens = hidden_states.shape[1];

    auto &&[norm_hidden_states, gate] = this->norm.forward(hidden_states, temb);
    debug("norm_hidden_states", norm_hidden_states);
    debug("gate", gate);

    Tensor residual = hidden_states;

    Tensor attn_output;

    debug("rotary_emb", rotary_emb);

    if (attnImpl == AttentionImpl::FlashAttention2) {
        Tensor qkv = Tensor::allocate(
            {batch_size, num_tokens, dim * 3}, norm_hidden_states.scalar_type(), norm_hidden_states.device());
        // qkv_proj.forward(norm_hidden_states, qkv, {});
        // debug("qkv_raw", qkv);

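        // Project one batch element at a time; Q/K normalization and the rotary
        // embedding are applied inside qkv_proj's epilogue.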
        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
                             qkv.slice(0, i, i + 1),
                             {},
                             norm_q.weight,
                             norm_k.weight,
                             rotary_emb);
        }
        debug("qkv", qkv);
        // Tensor qkv = forward_fc(qkv_proj, norm_hidden_states);

        // attn_output = attn.forward(qkv, {}, 0);
        attn_output = attn.forward(qkv);
        attn_output = attn_output.reshape({batch_size, num_tokens, num_heads * dim_head});
    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        // assert(batch_size == 1);

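        // The FP16 attention kernel expects sequences padded to a multiple of 256 tokens;
        // the padding is stripped again after attention (slice or strided copy below).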
        const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;

        Tensor q = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor k = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());
        Tensor v = Tensor::allocate(
            {batch_size, num_heads, num_tokens_pad, dim_head}, Tensor::FP16, norm_hidden_states.device());

        for (int i = 0; i < batch_size; i++) {
            qkv_proj.forward(norm_hidden_states.slice(0, i, i + 1),
                             {},
                             {},
                             norm_q.weight,
                             norm_k.weight,
                             rotary_emb,
                             q.slice(0, i, i + 1),
                             k.slice(0, i, i + 1),
                             v.slice(0, i, i + 1),
                             num_tokens);
        }

        debug("packed_q", q);
        debug("packed_k", k);
        debug("packed_v", v);

        Tensor o = Tensor::allocate({batch_size, num_tokens_pad, num_heads * dim_head},
                                    norm_hidden_states.scalar_type(),
                                    norm_hidden_states.device());

        kernels::attention_fp16(q, k, v, o, pow(dim_head, (-0.5)));

        if (batch_size == 1 || num_tokens_pad == num_tokens) {
            attn_output = o.slice(1, 0, num_tokens);
        } else {
            attn_output = Tensor::allocate({batch_size, num_tokens, num_heads * dim_head}, o.scalar_type(), o.device());
            checkCUDA(cudaMemcpy2DAsync(attn_output.data_ptr(),
                                        attn_output.stride(0) * attn_output.scalar_size(),
                                        o.data_ptr(),
                                        o.stride(0) * o.scalar_size(),
                                        attn_output.stride(0) * attn_output.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        getCurrentCUDAStream()));
        }
    } else {
        assert(false);
    }

    debug("raw_attn_output", attn_output);

    attn_output = forward_fc(out_proj, attn_output);
    debug("attn_output", attn_output);

    Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
    debug("ff_output", ff_output);

    hidden_states = kernels::add(attn_output, ff_output);
    debug("attn_ff_output", hidden_states);

    // kernels::mul_add(hidden_states, gate, residual);
    kernels::mul_add_batch(hidden_states, gate, true, 0.0, residual, true);

    nvtxRangePop();

    return hidden_states;
}

JointTransformerBlock::JointTransformerBlock(int dim,
                                             int num_attention_heads,
                                             int attention_head_dim,
                                             bool context_pre_only,
                                             bool use_fp4,
                                             Tensor::ScalarType dtype,
                                             Device device)
    : dim(dim), dim_head(attention_head_dim / num_attention_heads), num_heads(num_attention_heads),
      context_pre_only(context_pre_only), norm1(dim, false, dtype, device),
      norm1_context(dim, context_pre_only, dtype, device), qkv_proj(dim, dim * 3, true, use_fp4, dtype, device),
      qkv_proj_context(dim, dim * 3, true, use_fp4, dtype, device), norm_q(dim_head, 1e-6, false, dtype, device),
      norm_k(dim_head, 1e-6, false, dtype, device), norm_added_q(dim_head, 1e-6, false, dtype, device),
      norm_added_k(dim_head, 1e-6, false, dtype, device),
      attn(num_attention_heads, attention_head_dim / num_attention_heads, device),
      out_proj(dim, dim, true, use_fp4, dtype, device), out_proj_context(dim, dim, true, use_fp4, dtype, device),
      norm2(dim, 1e-6, false, dtype, device), norm2_context(dim, 1e-6, false, dtype, device),
      mlp_fc1(dim, dim * 4, true, use_fp4, dtype, device), mlp_fc2(dim * 4, dim, true, use_fp4, dtype, device),
      mlp_context_fc1(dim, dim * 4, true, use_fp4, dtype, device),
      mlp_context_fc2(dim * 4, dim, true, use_fp4, dtype, device) {
    registerChildren(norm1, "norm1")(norm1_context, "norm1_context")(qkv_proj, "qkv_proj")(qkv_proj_context,
                                                                                           "qkv_proj_context")(
        norm_q, "norm_q")(norm_k, "norm_k")(norm_added_q, "norm_added_q")(norm_added_k, "norm_added_k")(attn, "attn")(
        out_proj, "out_proj")(out_proj_context, "out_proj_context")(norm2, "norm2")(norm2_context, "norm2_context")(
        mlp_fc1, "mlp_fc1")(mlp_fc2, "mlp_fc2")(mlp_context_fc1, "mlp_context_fc1")(mlp_context_fc2, "mlp_context_fc2");
}

// hidden_states: [Batch, Width * Height, dim]
// encoder_hidden_states: [Batch, Token, dim]
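// Dual-stream (joint) FLUX block: image and text tokens are modulated and projected
// separately, concatenated (image tokens first) for a single joint attention pass, then
// split again so each stream gets its own output projection, adaLN gates, and MLP.
// When context_pre_only is set, the text stream skips its post-attention branch.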
std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
                                                          Tensor encoder_hidden_states,
                                                          Tensor temb,
                                                          Tensor rotary_emb,
                                                          Tensor rotary_emb_context,
                                                          float sparsityRatio) {
    int batch_size = hidden_states.shape[0];
    assert(encoder_hidden_states.shape[0] == batch_size);

    nvtxRangePushA("JointTransformerBlock");

    nvtxRangePushA("AdaNorm");

    int num_tokens_img = hidden_states.shape[1];
    int num_tokens_txt = encoder_hidden_states.shape[1];

    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);

    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}",
                  hidden_states.shape.str(),
                  encoder_hidden_states.shape.str(),
                  temb.shape.str());
    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);

    auto norm1_output         = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

#if 0
    norm1_output.x = hidden_states;
    norm1_context_output.x = encoder_hidden_states;
#endif

    debug("norm_hidden_states", norm1_output.x);
    debug("norm_encoder_hidden_states", norm1_context_output.x);

    constexpr int POOL_SIZE = Attention::POOL_SIZE;

    nvtxRangePop();

    auto stream = getCurrentCUDAStream();

    int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
    Tensor raw_attn_output;

    if (attnImpl == AttentionImpl::FlashAttention2) {
        num_tokens_img_pad = num_tokens_img;
        num_tokens_txt_pad = num_tokens_txt;

        Tensor concat;
        Tensor pool;

        {
            nvtxRangePushA("qkv_proj");

            const bool blockSparse = sparsityRatio > 0;

            const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
            concat               = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3},
                                      norm1_output.x.scalar_type(),
                                      norm1_output.x.device());

            pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3},
                                                  norm1_output.x.scalar_type(),
                                                  norm1_output.x.device())
                               : Tensor{};

            for (int i = 0; i < batch_size; i++) {
                // img first
                Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
                Tensor qkv_context =
                    concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);

                Tensor pool_qkv =
                    pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
                Tensor pool_qkv_context = pool.valid()
                                              ? pool.slice(0, i, i + 1)
                                                    .slice(1,
                                                           num_tokens_img / POOL_SIZE,
                                                           num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
                                              : Tensor{};

                // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
                // debug("qkv_raw", qkv);

                debug("rotary_emb", rotary_emb);

                qkv_proj.forward(
                    norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
                debug("qkv", qkv);

                // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
                // debug("qkv_context_raw", qkv_context);

                debug("rotary_emb_context", rotary_emb_context);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         qkv_context,
                                         pool_qkv_context,
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context);
                debug("qkv_context", qkv_context);
            }

            nvtxRangePop();
        }

        spdlog::debug("concat={}", concat.shape.str());
        debug("concat", concat);

        assert(concat.shape[2] == num_heads * dim_head * 3);

        nvtxRangePushA("Attention");

        if (pool.valid()) {
            raw_attn_output = attn.forward(concat, pool, sparsityRatio);
        } else {
            raw_attn_output = attn.forward(concat);
        }

        nvtxRangePop();

        spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());

        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});

    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
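        // FP16 kernel path: each stream is padded independently to a multiple of 256
        // tokens, and the projections write q/k/v directly into packed
        // [batch, heads, tokens, dim_head] buffers.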
        num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
        num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;

        Tensor concat_q, concat_k, concat_v;

        {
            nvtxRangePushA("qkv_proj");

            concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head},
                                        Tensor::FP16,
                                        norm1_output.x.device());
            concat_k = Tensor::empty_like(concat_q);
            concat_v = Tensor::empty_like(concat_q);

            for (int i = 0; i < batch_size; i++) {
                // img first
                auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
                auto sliceTxt = [&](Tensor x) {
                    return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
                };

                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1),
                                 {},
                                 {},
                                 norm_q.weight,
                                 norm_k.weight,
                                 rotary_emb,
                                 sliceImg(concat_q),
                                 sliceImg(concat_k),
                                 sliceImg(concat_v),
                                 num_tokens_img);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         {},
                                         {},
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context,
                                         sliceTxt(concat_q),
                                         sliceTxt(concat_k),
                                         sliceTxt(concat_v),
                                         num_tokens_txt);
            }

            debug("concat_q", concat_q);
            debug("concat_k", concat_k);
            debug("concat_v", concat_v);

            nvtxRangePop();
        }

        raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head},
                                           norm1_output.x.scalar_type(),
                                           norm1_output.x.device());

        nvtxRangePushA("Attention");

        kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));

        nvtxRangePop();

        raw_attn_output =
            raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
    } else {
        assert(false);
    }

    debug("raw_attn_output", raw_attn_output);

    {
        nvtxRangePushA("o_proj");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;

        // raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split =
                raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
        } else {
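            // With batch_size > 1 the joint output rows are strided by the padded length,
            // so copy only the image-token rows of each batch element into a contiguous buffer.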
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("img.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
        debug("img.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
        hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", hidden_states.shape.str());

        Tensor norm_hidden_states = norm2.forward(hidden_states);
        debug("scale_mlp", scale_mlp);
        debug("shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        Tensor norm_hidden_states = hidden_states;
#endif

        // Tensor ff_output = mlp_fc2.forward(GELU::forward(mlp_fc1.forward(norm_hidden_states)));
        debug("img.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
        debug("img.ff_output", ff_output);

        debug("gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
        hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", hidden_states.shape.str());
    }

    if (context_pre_only) {
        return {hidden_states, encoder_hidden_states};
    }

    {
        nvtxRangePushA("o_proj_context");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_context_output;

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt)
                                        .reshape({batch_size, num_tokens_txt, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head *
                                                                               raw_attn_output_split.scalar_size(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("context.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj_context,
                       raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
        debug("context.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", encoder_hidden_states.shape.str());

        Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
        debug("c_scale_mlp", scale_mlp);
        debug("c_shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        auto norm_hidden_states = encoder_hidden_states;
#endif

        // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
        // Tensor ff_output =
        // mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
        debug("context.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
        debug("context.ff_output", ff_output);

        debug("c_gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", encoder_hidden_states.shape.str());
    }

    nvtxRangePop();

    return {hidden_states, encoder_hidden_states};
}
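
// Computes the joint block's adaLN modulation and QKV projection (FlashAttention layout)
// without running attention, and returns the image-stream query heads as a contiguous
// [batch, num_tokens_img, num_heads * dim_head] tensor.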
Tensor JointTransformerBlock::get_q_heads(Tensor hidden_states,
                                          Tensor encoder_hidden_states,
                                          Tensor temb,
                                          Tensor rotary_emb,
                                          Tensor rotary_emb_context,
                                          float sparsityRatio) {
    int batch_size     = hidden_states.shape[0];
    int num_tokens_img = hidden_states.shape[1];
    int num_tokens_txt = encoder_hidden_states.shape[1];

    // Apply AdaNorm.
    auto norm1_output         = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

    Tensor concat = Tensor::allocate(
        {batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());

    const bool blockSparse  = sparsityRatio > 0;
    constexpr int POOL_SIZE = Attention::POOL_SIZE;
    const int poolTokens    = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
    Tensor pool =
        blockSparse
            ? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
            : Tensor{};

    // QKV Projection.
    for (int i = 0; i < batch_size; i++) {
        Tensor qkv         = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
        Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
        Tensor pool_qkv    = pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
        Tensor pool_qkv_context =
            pool.valid() ? pool.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, poolTokens) : Tensor{};

        qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
        qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                 qkv_context,
                                 pool_qkv_context,
                                 norm_added_q.weight,
                                 norm_added_k.weight,
                                 rotary_emb_context);
    }

    // Extract and return q_heads.
    Tensor q_all = concat.slice(2, 0, num_heads * dim_head);
    Tensor q_img = q_all.slice(1, 0, num_tokens_img);

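    // q_img is a strided view into the packed QKV buffer; copy it into a contiguous
    // tensor with one cudaMemcpy2DAsync per batch element.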
    auto make_contiguous = [&](const Tensor &t) {
        int B            = t.shape.dataExtent[0];
        int R            = t.shape.dataExtent[1];
        int C            = t.shape.dataExtent[2];
        size_t E         = t.scalar_size();
        size_t src_pitch = t.stride(1) * E;
        size_t dst_pitch = C * E;
        size_t width     = C * E;
        size_t height    = R;
        Tensor out       = Tensor::allocate({B, R, C}, t.scalarType, t.device());
        auto stream      = getCurrentCUDAStream();
        for (int b = 0; b < B; ++b) {
            const void *src = (const char *)t.data_ptr<char>() + t.stride(0) * b * E;
            void *dst       = (char *)out.data_ptr<char>() + out.stride(0) * b * E;
            checkCUDA(
                cudaMemcpy2DAsync(dst, dst_pitch, src, src_pitch, width, height, cudaMemcpyDeviceToDevice, stream));
        }
        return out;
    };
    return make_contiguous(q_img);
}

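// Variant of JointTransformerBlock::forward that also captures the image-stream query
// heads during the QKV projection (FlashAttention path: contiguous copy of the q slice;
// Nunchaku FP16 path: the packed concat_q buffer) for use by an IP-Adapter branch.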
std::tuple<Tensor, Tensor, Tensor> JointTransformerBlock::forward_ip_adapter_branch(Tensor hidden_states,
                                                                                    Tensor encoder_hidden_states,
                                                                                    Tensor temb,
                                                                                    Tensor rotary_emb,
                                                                                    Tensor rotary_emb_context,
                                                                                    float sparsityRatio) {
    int batch_size = hidden_states.shape[0];
    assert(encoder_hidden_states.shape[0] == batch_size);

    nvtxRangePushA("JointTransformerBlock");

    nvtxRangePushA("AdaNorm");

    int num_tokens_img = hidden_states.shape[1];
    int num_tokens_txt = encoder_hidden_states.shape[1];

    assert(hidden_states.shape[2] == dim);
    assert(encoder_hidden_states.shape[2] == dim);

    Tensor q_heads;

    auto make_contiguous = [&](const Tensor &t) {
        int B            = t.shape.dataExtent[0];
        int R            = t.shape.dataExtent[1];
        int C            = t.shape.dataExtent[2];
        size_t E         = t.scalar_size();
        size_t src_pitch = t.stride(1) * E;

        size_t dst_pitch = C * E;
        size_t width     = C * E;
        size_t height    = R;

        Tensor out = Tensor::allocate({B, R, C}, t.scalarType, t.device());

        auto stream = getCurrentCUDAStream();
        for (int b = 0; b < B; ++b) {
            const void *src = (const char *)t.data_ptr<char>() + t.stride(0) * b * E;
            void *dst       = (char *)out.data_ptr<char>() + out.stride(0) * b * E;
            checkCUDA(
                cudaMemcpy2DAsync(dst, dst_pitch, src, src_pitch, width, height, cudaMemcpyDeviceToDevice, stream));
        }
        return out;
    };

    spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}",
                  hidden_states.shape.str(),
                  encoder_hidden_states.shape.str(),
                  temb.shape.str());
    spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);

    auto norm1_output         = norm1.forward(hidden_states, temb);
    auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);

#if 0
    norm1_output.x = hidden_states;
    norm1_context_output.x = encoder_hidden_states;
#endif

    debug("norm_hidden_states", norm1_output.x);
    debug("norm_encoder_hidden_states", norm1_context_output.x);

    constexpr int POOL_SIZE = Attention::POOL_SIZE;

    nvtxRangePop();

    auto stream = getCurrentCUDAStream();

    int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
    Tensor raw_attn_output;

    if (attnImpl == AttentionImpl::FlashAttention2) {
        num_tokens_img_pad = num_tokens_img;
        num_tokens_txt_pad = num_tokens_txt;

        Tensor concat;
        Tensor pool;

        {
            nvtxRangePushA("qkv_proj");

            const bool blockSparse = sparsityRatio > 0;

            const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
            concat               = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3},
                                      norm1_output.x.scalar_type(),
                                      norm1_output.x.device());

            pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3},
                                                  norm1_output.x.scalar_type(),
                                                  norm1_output.x.device())
                               : Tensor{};

            for (int i = 0; i < batch_size; i++) {
                // img first
                Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
                Tensor qkv_context =
                    concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);

                Tensor pool_qkv =
                    pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
                Tensor pool_qkv_context = pool.valid()
                                              ? pool.slice(0, i, i + 1)
                                                    .slice(1,
                                                           num_tokens_img / POOL_SIZE,
                                                           num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
                                              : Tensor{};

                // qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
                // debug("qkv_raw", qkv);

                debug("rotary_emb", rotary_emb);

                qkv_proj.forward(
                    norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
                debug("qkv", qkv);

                // qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
                // debug("qkv_context_raw", qkv_context);

                debug("rotary_emb_context", rotary_emb_context);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         qkv_context,
                                         pool_qkv_context,
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context);
                debug("qkv_context", qkv_context);
            }

            nvtxRangePop();
        }

        spdlog::debug("concat={}", concat.shape.str());
        debug("concat", concat);

        assert(concat.shape[2] == num_heads * dim_head * 3);

        nvtxRangePushA("Attention");

        if (pool.valid()) {
            raw_attn_output = attn.forward(concat, pool, sparsityRatio);
        } else {
            raw_attn_output = attn.forward(concat);
        }

        nvtxRangePop();

        spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());

        raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});

        // IP_adapter
        Tensor q_all = concat.slice(2, 0, num_heads * dim_head); // [B, N_total, dim]

        Tensor q_img = q_all.slice(1, 0, num_tokens_img); // [B, N_img, dim]

        q_heads = make_contiguous(q_img);

    } else if (attnImpl == AttentionImpl::NunchakuFP16) {
        num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
        num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;

        Tensor concat_q, concat_k, concat_v;

        {
            nvtxRangePushA("qkv_proj");

            concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head},
                                        Tensor::FP16,
                                        norm1_output.x.device());
            concat_k = Tensor::empty_like(concat_q);
            concat_v = Tensor::empty_like(concat_q);

            for (int i = 0; i < batch_size; i++) {
                // img first
                auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
                auto sliceTxt = [&](Tensor x) {
                    return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
                };

                qkv_proj.forward(norm1_output.x.slice(0, i, i + 1),
                                 {},
                                 {},
                                 norm_q.weight,
                                 norm_k.weight,
                                 rotary_emb,
                                 sliceImg(concat_q),
                                 sliceImg(concat_k),
                                 sliceImg(concat_v),
                                 num_tokens_img);

                qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
                                         {},
                                         {},
                                         norm_added_q.weight,
                                         norm_added_k.weight,
                                         rotary_emb_context,
                                         sliceTxt(concat_q),
                                         sliceTxt(concat_k),
                                         sliceTxt(concat_v),
                                         num_tokens_txt);
            }

            debug("concat_q", concat_q);
            debug("concat_k", concat_k);
            debug("concat_v", concat_v);

            nvtxRangePop();
        }

        raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head},
                                           norm1_output.x.scalar_type(),
                                           norm1_output.x.device());

        nvtxRangePushA("Attention");

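        // Fused FP16 attention over the padded sequences; the last argument is the usual
        // softmax scale 1/sqrt(dim_head).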
        kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));

        nvtxRangePop();

        raw_attn_output =
            raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});

        q_heads = concat_q;
    } else {
        assert(false);
    }

    debug("raw_attn_output", raw_attn_output);

    {
        nvtxRangePushA("o_proj");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;

        // raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split =
                raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
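            // For batch_size > 1 each sample's image tokens are followed by the (padded) text
            // tokens inside raw_attn_output, so a strided 2D copy gathers a contiguous
            // [num_tokens_img, num_heads * dim_head] slab per sample.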
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("img.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
        debug("img.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, hidden_states);
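        // Batched in-place gated residual: attn_output is scaled by gate_msa and the previous
        // hidden_states are added back, then the result replaces hidden_states below.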
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
        hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", hidden_states.shape.str());

        Tensor norm_hidden_states = norm2.forward(hidden_states);
        debug("scale_mlp", scale_mlp);
        debug("shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        Tensor norm_hidden_states = hidden_states;
#endif

        // Tensor ff_output = mlp_fc2.forward(GELU::forward(mlp_fc1.forward(norm_hidden_states)));
        debug("img.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
        debug("img.ff_output", ff_output);

        debug("gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
        hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", hidden_states.shape.str());
    }

    if (context_pre_only) {
        return {hidden_states, encoder_hidden_states, q_heads};
    }

    {
        nvtxRangePushA("o_proj_context");

        auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_context_output;

        Tensor raw_attn_output_split;
        if (batch_size == 1) {
            raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt)
                                        .reshape({batch_size, num_tokens_txt, num_heads * dim_head});
        } else {
            raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head},
                                                     raw_attn_output.scalar_type(),
                                                     raw_attn_output.device());
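            // Same strided gather as in the image branch above, with the source offset past the
            // (padded) image tokens so only the text region is copied out.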
            checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head *
                                                                               raw_attn_output_split.scalar_size(),
                                        (num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
                                            raw_attn_output.scalar_size(),
                                        num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
                                        batch_size,
                                        cudaMemcpyDeviceToDevice,
                                        stream));
        }

        spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
        debug("context.raw_attn_output_split", raw_attn_output_split);

        Tensor attn_output =
            forward_fc(out_proj_context,
                       raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
        debug("context.attn_output", attn_output);

#if 1
        // kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
        kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(attn_output);

        nvtxRangePop();
        nvtxRangePushA("MLP");

        spdlog::debug("attn_output={}", encoder_hidden_states.shape.str());

        Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
        debug("c_scale_mlp", scale_mlp);
        debug("c_shift_mlp", shift_mlp);
        // kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);

        spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
        auto norm_hidden_states = encoder_hidden_states;
#endif

        // Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
        // Tensor ff_output =
        // mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
        debug("context.ff_input", norm_hidden_states);
        Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
        debug("context.ff_output", ff_output);

        debug("c_gate_mlp", gate_mlp);
        // kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
        kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
        encoder_hidden_states = std::move(ff_output);

        nvtxRangePop();

        spdlog::debug("ff_output={}", encoder_hidden_states.shape.str());
    }

    nvtxRangePop();

    return {hidden_states, encoder_hidden_states, q_heads};
}

FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
    : dtype(dtype), offload(offload) {
    CUDADeviceContext model_construction_ctx(device.idx);

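    // FLUX.1 stacks 19 joint (dual-stream) blocks followed by 38 single-stream blocks. With
    // offloading enabled, every block except the first joint block is switched to lazy loading
    // and its parameters are released right after construction, to be reloaded on demand.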
    for (int i = 0; i < 19; i++) {
        transformer_blocks.push_back(
            std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
        registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
        if (offload && i > 0) { // don't offload first block
            transformer_blocks.back()->setLazyLoad(true);
            transformer_blocks.back()->releaseLazyParams();
        }
    }
    for (int i = 0; i < 38; i++) {
        single_transformer_blocks.push_back(
            std::make_unique<FluxSingleTransformerBlock>(3072, 24, 3072, 4, use_fp4, dtype, device));
        registerChildren(*single_transformer_blocks.back(), format("single_transformer_blocks.{}", i));
        if (offload) {
            single_transformer_blocks.back()->setLazyLoad(true);
            single_transformer_blocks.back()->releaseLazyParams();
        }
    }
}

Tensor FluxModel::forward(Tensor hidden_states,
                          Tensor encoder_hidden_states,
                          Tensor temb,
                          Tensor rotary_emb_img,
                          Tensor rotary_emb_context,
                          Tensor rotary_emb_single,
                          Tensor controlnet_block_samples,
                          Tensor controlnet_single_block_samples,
                          bool skip_first_layer) {
    const int batch_size           = hidden_states.shape[0];
    const Tensor::ScalarType dtype = hidden_states.dtype();
    const Device device            = hidden_states.device();

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    const int numLayers = transformer_blocks.size() + single_transformer_blocks.size();

    Tensor concat;

    auto compute = [&](int layer) {
        if (skip_first_layer && size_t(layer) == 0)
            return;
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            std::tie(hidden_states, encoder_hidden_states) =
                block->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
            if (controlnet_block_samples.valid()) {
                const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

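                // There can be fewer ControlNet residuals than blocks, so interval_control
                // spreads them evenly and consecutive blocks share the same sample.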
                int interval_control =
                    ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
                int block_index = layer / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_block_samples;

                hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
            }
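            // Optional residual injection supplied via set_residual_callback; it is applied on
            // every 2nd joint block here and on every 4th single block further below.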
            if (residual_callback && layer % 2 == 0) {
                Tensor residual = residual_callback(hidden_states);
                hidden_states   = kernels::add(hidden_states, residual);
            }
        } else {
            if (size_t(layer) == transformer_blocks.size()) {
                // txt first, same as diffusers
                concat = Tensor::allocate({batch_size, txt_tokens + img_tokens, 3072}, dtype, device);
                for (int i = 0; i < batch_size; i++) {
                    concat.slice(0, i, i + 1).slice(1, 0, txt_tokens).copy_(encoder_hidden_states.slice(0, i, i + 1));
                    concat.slice(0, i, i + 1)
                        .slice(1, txt_tokens, txt_tokens + img_tokens)
                        .copy_(hidden_states.slice(0, i, i + 1));
                }
                hidden_states         = concat;
                encoder_hidden_states = {};
            }

            auto &block   = single_transformer_blocks.at(layer - transformer_blocks.size());
            hidden_states = block->forward(hidden_states, temb, rotary_emb_single);
            if (controlnet_single_block_samples.valid()) {
                const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

                int interval_control =
                    ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
                int block_index = (layer - transformer_blocks.size()) / interval_control;
                // Xlabs ControlNet
                // block_index = layer % num_controlnet_single_block_samples

                auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice      = kernels::add(slice, controlnet_single_block_samples[block_index]);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
            size_t local_layer_idx = layer - transformer_blocks.size();
            if (residual_callback && local_layer_idx % 4 == 0) {
                Tensor callback_input = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                Tensor residual       = residual_callback(callback_input);
                auto slice            = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
                slice                 = kernels::add(slice, residual);
                hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
            }
        }
    };
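    // load/unload restore and release a block's lazily held weights; LayerOffloadHelper runs
    // them alongside compute when offloading is enabled.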
    auto load = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->loadLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->loadLazyParams();
        }
    };
    auto unload = [&](int layer) {
        if (size_t(layer) < transformer_blocks.size()) {
            auto &block = transformer_blocks.at(layer);
            block->releaseLazyParams();
        } else {
            auto &block = single_transformer_blocks.at(layer - transformer_blocks.size());
            block->releaseLazyParams();
        }
    };

    LayerOffloadHelper helper(this->offload, numLayers, compute, load, unload);
    helper.run();

    return hidden_states;
}

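// Runs the block selected by `layer` on its own (applying the matching ControlNet residual,
// if provided), and loads/releases its lazily held weights when offloading is enabled.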
std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
                                                    Tensor hidden_states,
                                                    Tensor encoder_hidden_states,
                                                    Tensor temb,
                                                    Tensor rotary_emb_img,
                                                    Tensor rotary_emb_context,
                                                    Tensor controlnet_block_samples,
                                                    Tensor controlnet_single_block_samples) {

    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->loadLazyParams();
        } else {
            single_transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
        }
    }

    if (layer < transformer_blocks.size()) {
        std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
            hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    } else {
        std::tie(hidden_states, encoder_hidden_states) =
            transformer_blocks.at(layer - transformer_blocks.size())
                ->forward(hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    }

    const int txt_tokens = encoder_hidden_states.shape[1];
    const int img_tokens = hidden_states.shape[1];

    if (layer < transformer_blocks.size() && controlnet_block_samples.valid()) {
        const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
        int block_index      = layer / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_block_samples;

        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
    } else if (layer >= transformer_blocks.size() && controlnet_single_block_samples.valid()) {
        const int num_controlnet_single_block_samples = controlnet_single_block_samples.shape[0];

        int interval_control =
            ceilDiv(single_transformer_blocks.size(), static_cast<size_t>(num_controlnet_single_block_samples));
        int block_index = (layer - transformer_blocks.size()) / interval_control;
        // Xlabs ControlNet
        // block_index = layer % num_controlnet_single_block_samples

        auto slice = hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens);
        slice      = kernels::add(slice, controlnet_single_block_samples[block_index]);
        hidden_states.slice(1, txt_tokens, txt_tokens + img_tokens).copy_(slice);
    }

    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->releaseLazyParams();
        } else {
            single_transformer_blocks.at(layer - transformer_blocks.size())->releaseLazyParams();
        }
    }

    return {hidden_states, encoder_hidden_states};
}

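// Variant of forward_layer used by the IP-Adapter path: in addition to the updated streams it
// returns the image-token query heads (ip_query) so the caller can run the extra IP-Adapter
// cross-attention outside this module. Only joint transformer blocks are expected here.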
std::tuple<Tensor, Tensor, Tensor> FluxModel::forward_ip_adapter(size_t layer,
                                                                 Tensor hidden_states,         // [B, Nq, dim]
                                                                 Tensor encoder_hidden_states, // [B, Nt, dim]
                                                                 Tensor temb,
                                                                 Tensor rotary_emb_img, // [B, Nq, dim_head]
                                                                 Tensor rotary_emb_context,
                                                                 Tensor controlnet_block_samples,
                                                                 Tensor controlnet_single_block_samples) {
    if (offload && layer > 0) {
        if (layer < transformer_blocks.size()) {
            transformer_blocks.at(layer)->loadLazyParams();
        } else {
            single_transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
        }
    }

    std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
        hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
    Tensor ip_query = transformer_blocks.at(layer)->get_q_heads(
        hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);

    if (controlnet_block_samples.valid()) {
        const int num_controlnet_block_samples = controlnet_block_samples.shape[0];

        int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
        int block_index      = layer / interval_control;

        hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
    }

    if (offload && layer > 0) {
        transformer_blocks.at(layer)->releaseLazyParams();
    }

    return {hidden_states, encoder_hidden_states, ip_query};
}

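// Switches the attention backend for every block: either the FlashAttention-based path or the
// NunchakuFP16 kernel handled in the attention branches above.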
void FluxModel::setAttentionImpl(AttentionImpl impl) {
    for (auto &&block : this->transformer_blocks) {
        block->attnImpl = impl;
    }
    for (auto &&block : this->single_transformer_blocks) {
        block->attnImpl = impl;
    }
}
void FluxModel::set_residual_callback(std::function<Tensor(const Tensor &)> cb) {
    residual_callback = std::move(cb);
}