#pragma once

#include "common.h"
#include "Tensor.h"
#include "Module.h"
#include "Linear.h"
#include "layernorm.h"
#include <pybind11/functional.h>
namespace pybind11 {
class function;
}

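// Selects which attention backend the transformer blocks dispatch to:
// FlashAttention-2 kernels or Nunchaku's in-house FP16 attention implementation.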
enum class AttentionImpl {
    FlashAttention2 = 0,
    NunchakuFP16,
};

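// Adaptive LayerNorm (adaLN-Zero) for the single-stream blocks: a linear layer
// projects the conditioning embedding into shift/scale/gate terms that modulate
// the normalized input, following the diffusers AdaLayerNormZeroSingle layout.
// With USE_4BIT the projection runs as a 4-bit AWQ GEMV, otherwise as a W8A8 GEMM.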
class AdaLayerNormZeroSingle : public Module {
public:
    static constexpr bool USE_4BIT = true;
    using GEMM                     = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;

    struct Output {
        Tensor x;
        Tensor gate_msa;
    };

public:
    AdaLayerNormZeroSingle(int dim, Tensor::ScalarType dtype, Device device);
    Output forward(Tensor x, Tensor emb);

public:
    const int dim;

private:
    GEMM linear;
    LayerNorm norm;
};

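// Full adaLN-Zero used by the joint blocks: besides the attention modulation it
// also emits shift/scale/gate terms for the MLP branch (see Output). When
// `pre_only` is set, presumably only the pre-attention modulation is produced.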
class AdaLayerNormZero : public Module {
public:
    static constexpr bool USE_4BIT = true;
    using GEMM                     = std::conditional_t<USE_4BIT, GEMV_AWQ, GEMM_W8A8>;

    struct Output {
        Tensor x;
        Tensor gate_msa;
        Tensor shift_mlp;
        Tensor scale_mlp;
        Tensor gate_mlp;
    };

public:
    AdaLayerNormZero(int dim, bool pre_only, Tensor::ScalarType dtype, Device device);
    Output forward(Tensor x, Tensor emb);

public:
    const int dim;
    const bool pre_only;

private:
    GEMM linear;
    LayerNorm norm;
};

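// Multi-head attention over a packed QKV tensor. The overload taking pool_qkv
// and sparsityRatio appears to drive a sparse-attention path using QKV pooled
// over POOL_SIZE-token windows.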
class Attention : public Module {
public:
    static constexpr int POOL_SIZE = 128;

    Attention(int num_heads, int dim_head, Device device);
    Tensor forward(Tensor qkv);
    Tensor forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio);

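    // Presumably toggles force_fp16 on `module` and all of its child modules.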
    static void setForceFP16(Module *module, bool value);

public:
    const int num_heads;
    const int dim_head;
    bool force_fp16;

private:
    Tensor cu_seqlens_cpu;
    Tensor headmask_type;
};

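// Single-stream Flux transformer block: adaLN-Zero modulation (norm), a fused
// QKV projection with Q/K RMSNorm, attention, an MLP branch (mlp_fc1/mlp_fc2),
// and an output projection. With USE_4BIT the GEMMs run as W4A4, otherwise as W8A8.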
class FluxSingleTransformerBlock : public Module {
public:
    static constexpr bool USE_4BIT = true;
    using GEMM                     = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;

    FluxSingleTransformerBlock(int dim,
                               int num_attention_heads,
                               int attention_head_dim,
                               int mlp_ratio,
                               bool use_fp4,
                               Tensor::ScalarType dtype,
                               Device device);
    Tensor forward(Tensor hidden_states, Tensor temb, Tensor rotary_emb);

public:
    const int dim;
    const int dim_head;
    const int num_heads;
    const int mlp_hidden_dim;

    AttentionImpl attnImpl = AttentionImpl::FlashAttention2;

private:
    AdaLayerNormZeroSingle norm;
    GEMM mlp_fc1;
    GEMM mlp_fc2;
    GEMM qkv_proj;
    RMSNorm norm_q, norm_k;
    Attention attn;
    GEMM out_proj;
};

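// Dual-stream (joint image/text) transformer block: the image and context
// streams keep separate QKV/output projections and MLPs but attend jointly.
// forward() returns the updated (hidden_states, encoder_hidden_states) pair;
// sparsityRatio is presumably forwarded to the sparse-attention path.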
class JointTransformerBlock : public Module {
public:
    static constexpr bool USE_4BIT = true;
    using GEMM                     = std::conditional_t<USE_4BIT, GEMM_W4A4, GEMM_W8A8>;

    JointTransformerBlock(int dim,
                          int num_attention_heads,
                          int attention_head_dim,
                          bool context_pre_only,
                          bool use_fp4,
                          Tensor::ScalarType dtype,
                          Device device);
    std::tuple<Tensor, Tensor> forward(Tensor hidden_states,
                                       Tensor encoder_hidden_states,
                                       Tensor temb,
                                       Tensor rotary_emb,
                                       Tensor rotary_emb_context,
                                       float sparsityRatio);

public:
    const int dim;
    const int dim_head;
    const int num_heads;
    const bool context_pre_only;
    AdaLayerNormZero norm1;

    AttentionImpl attnImpl = AttentionImpl::FlashAttention2;

private:
    AdaLayerNormZero norm1_context;
    GEMM qkv_proj;
    GEMM qkv_proj_context;
    RMSNorm norm_q, norm_k;
    RMSNorm norm_added_q, norm_added_k;
    Attention attn;
    GEMM out_proj;
    GEMM out_proj_context;
    LayerNorm norm2;
    LayerNorm norm2_context;
    GEMM mlp_fc1, mlp_fc2;
    GEMM mlp_context_fc1, mlp_context_fc2;
};

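// Top-level Flux transformer: a stack of JointTransformerBlocks followed by
// FluxSingleTransformerBlocks. The controlnet_* tensors, when present, are
// presumably added as per-block residuals, and `offload` enables weight
// offloading between layers.
//
// Illustrative usage sketch (the ScalarType enumerator and Device spellings are
// assumptions, not taken from this header; the real entry points are the
// pybind11 bindings):
//
//   FluxModel model(/*use_fp4=*/false, /*offload=*/false,
//                   /*dtype=*/Tensor::FP16, /*device=*/Device::cuda(0));
//   Tensor out = model.forward(hidden_states, encoder_hidden_states, temb,
//                              rotary_emb_img, rotary_emb_context, rotary_emb_single,
//                              controlnet_block_samples, controlnet_single_block_samples);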
class FluxModel : public Module {
public:
    FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device);
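
    // forward() runs the full stack; skip_first_layer presumably skips the first
    // joint block (e.g. when its output is cached externally). forward_layer()
    // runs only the joint block at index `layer` and returns the updated
    // (hidden_states, encoder_hidden_states) pair.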
    Tensor forward(Tensor hidden_states,
                   Tensor encoder_hidden_states,
                   Tensor temb,
                   Tensor rotary_emb_img,
                   Tensor rotary_emb_context,
                   Tensor rotary_emb_single,
                   Tensor controlnet_block_samples,
                   Tensor controlnet_single_block_samples,
                   bool skip_first_layer = false);
    std::tuple<Tensor, Tensor> forward_layer(size_t layer,
                                             Tensor hidden_states,
                                             Tensor encoder_hidden_states,
                                             Tensor temb,
                                             Tensor rotary_emb_img,
                                             Tensor rotary_emb_context,
                                             Tensor controlnet_block_samples,
                                             Tensor controlnet_single_block_samples);
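
    // Selects the attention backend (FlashAttention-2 or Nunchaku FP16) used by all blocks.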
    void setAttentionImpl(AttentionImpl impl);

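    // Registers a hook applied to intermediate residual tensors; presumably
    // installed from Python via the pybind11 bindings (hence the pybind11
    // forward declaration at the top of this header).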
    void set_residual_callback(std::function<Tensor(const Tensor &)> cb);

public:
    const Tensor::ScalarType dtype;

    std::vector<std::unique_ptr<JointTransformerBlock>> transformer_blocks;
    std::vector<std::unique_ptr<FluxSingleTransformerBlock>> single_transformer_blocks;

    std::function<Tensor(const Tensor &)> residual_callback;

private:
    bool offload;
};