#pragma once

#include "common.h"
#include "Tensor.h"
#include "Linear.h"
#include "layernorm.h"

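// Sana's lightweight linear self-attention block. Both the fused QKV
// projection and the output projection run as W4A4 (4-bit weight/activation)
// quantized GEMMs; dim_pad presumably holds dim padded to the GEMM tile size.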
class SanaLinearAttention : public Module {
public:
    SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);

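    // forward_pag: attention path used when perturbed-attention guidance (PAG)
    // is enabled, with pag_to_v standing in for the value projection; `cfg`
    // likely signals classifier-free-guidance batching (both readings are
    // assumptions from the member names).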
    Tensor forward(Tensor x, Tensor out = {});
    Tensor forward_pag(Tensor x, bool cfg);

public:
    const int dim;
    const int dim_pad;

private:
    GEMM_W4A4 qkv_proj;
    GEMM_W4A4 out_proj;

    std::optional<GEMM_W4A4> pag_to_v;
};

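// Cross-attention from image tokens to text-encoder tokens. The query and
// output projections are W4A4 quantized; the KV projection over the (short)
// text sequence stays in 16-bit precision (GEMM_F16).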
class MultiHeadCrossAttention : public Module {
public:
    MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device);

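    // cu_seqlens_img / cu_seqlens_txt: cumulative sequence lengths per batch
    // entry, the varlen layout used by flash-attention-style kernels (an
    // assumption based on the naming convention).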
    Tensor forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt);

public:
    const int num_heads;
    const int head_dim;

private:
    GEMM_W4A4 q_linear;
    GEMM_F16 kv_linear;
    GEMM_W4A4 out_proj;
};

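// Sana's GLU "Mix-FFN" feed-forward: a pointwise expansion (inverted_conv),
// a depthwise convolution over the H x W token grid, GLU gating, and a
// pointwise projection back down (point_conv).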
class SanaGLUMBConv : public Module {
public:
    SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device);

    Tensor forward(Tensor x, int H, int W);

public:
    const int in_features;
    const int hidden_features;

private:
    GEMM_W4A4 inverted_conv;
    DWCONV depth_conv;
    GEMM_W4A4 point_conv;
};

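// One Sana transformer layer: timestep modulation via a learned
// scale_shift_table (AdaLN-single style, as in the PixArt lineage), linear
// self-attention, text cross-attention, and the GLUMBConv feed-forward.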
class SanaLinearTransformerBlock : public Module {
public:
    SanaLinearTransformerBlock(int hidden_size,
                               int intermediate_size,
                               int num_cross_attention_heads,
                               bool pag,
                               bool use_fp4,
                               Tensor::ScalarType dtype,
                               Device device);

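    // H and W give the spatial token-grid size so the feed-forward's depthwise
    // conv can fold the sequence back to 2-D; pag and cfg select the guidance
    // branches at inference time.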
    Tensor forward(Tensor hidden_states,
                   Tensor encoder_hidden_states,
                   Tensor timestep,
                   Tensor cu_seqlens_img,
                   Tensor cu_seqlens_txt,
                   int H,
                   int W,
                   bool pag,
                   bool cfg);

public:
    const int hidden_size;
    const int num_cross_attention_heads;

private:
    Tensor scale_shift_table;
    // Tensor ones;

    SanaLinearAttention attn;
    MultiHeadCrossAttention cross_attn;
    SanaGLUMBConv ff;

    LayerNorm norm1, norm2;
};

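// Model hyperparameters. pag_layers lists the block indices that instantiate
// the perturbed-attention-guidance path; use_fp4 switches the quantized GEMMs
// from INT4 to FP4 (an assumption based on how the flag is threaded through
// every submodule).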
struct SanaConfig {
    int num_layers;
    int num_attention_heads;
    int attention_head_dim;
    int num_cross_attention_heads;
    double expand_ratio;
    std::vector<int> pag_layers;
    bool use_fp4;
};

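// Top-level Sana transformer: a stack of num_layers SanaLinearTransformerBlock
// instances applied over the latent token sequence.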
class SanaModel : public Module {
public:
    SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);

    Tensor forward(Tensor hidden_states,
                   Tensor encoder_hidden_states,
                   Tensor timestep,
                   Tensor cu_seqlens_img,
                   Tensor cu_seqlens_txt,
                   int H,
                   int W,
                   bool pag,
                   bool cfg,
                   bool skip_first_layer);

public:
    const SanaConfig config;

public:
    std::vector<std::unique_ptr<SanaLinearTransformerBlock>> transformer_blocks;
};
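
// Usage sketch (hypothetical; only the declarations above are assumed, and the
// dtype/device/tensor setup is elided):
//
//   SanaConfig cfg = /* fill from the model's config */;
//   SanaModel model(cfg, dtype, device);
//   Tensor out = model.forward(hidden_states, encoder_hidden_states, timestep,
//                              cu_seqlens_img, cu_seqlens_txt, H, W,
//                              /*pag=*/false, /*cfg=*/true,
//                              /*skip_first_layer=*/false);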