"lightx2v/server/README.md" did not exist on "fab43a07e6b48d614a89fe56c1ab2e63a5f9356a"
layernorm.h 1.92 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#pragma once

#include "common.h"
#include "Tensor.h"
#include "Module.h"

// Standard layer normalization over the last (hidden) dimension, with an
// optional learned elementwise affine transform (weight and bias).
class LayerNorm : public Module {
public:
    LayerNorm(int hidden_size, float eps, bool elementwise_affine, Tensor::ScalarType dtype, Device device);
    Tensor forward(Tensor x);

public:
    const int hidden_size;
    const float eps;

private:
    Tensor weight;  // scale (gamma); only applied when elementwise_affine is true
    Tensor bias;    // shift (beta); only applied when elementwise_affine is true
};
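
// A minimal scalar sketch of the math LayerNorm::forward is expected to
// implement (the real forward presumably dispatches to a fused device kernel;
// layer_norm_ref and its raw-pointer signature are illustrative only, not
// part of this API):
#include <cmath>
#include <cstddef>

inline void layer_norm_ref(const float *x, const float *weight, const float *bias,
                           float *y, std::size_t n, float eps) {
    // Mean and variance over the hidden dimension.
    float mean = 0.0f;
    for (std::size_t i = 0; i < n; i++) mean += x[i];
    mean /= static_cast<float>(n);
    float var = 0.0f;
    for (std::size_t i = 0; i < n; i++) var += (x[i] - mean) * (x[i] - mean);
    var /= static_cast<float>(n);
    // Normalize, then apply the elementwise affine transform.
    const float inv_std = 1.0f / std::sqrt(var + eps);
    for (std::size_t i = 0; i < n; i++)
        y[i] = (x[i] - mean) * inv_std * weight[i] + bias[i];
}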

// RMS normalization: scales each token by the reciprocal of its root mean
// square, then multiplies by a learned per-channel weight. When use_quant is
// set, forward() is expected to also quantize its output.
class RMSNorm : public Module {
public:
    RMSNorm(int hidden_size, float eps, bool use_quant, Tensor::ScalarType dtype, Device device)
        : use_quant(use_quant), variance_epsilon(eps)
    {
        weight = Tensor::allocate({hidden_size}, dtype, device);
        registerParams(weight, "weight");
    }
    Tensor forward(Tensor x);

public:
    const bool use_quant;
    const float variance_epsilon;
    Tensor weight;
};
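
// A scalar sketch of the RMSNorm math (illustrative only; the actual forward
// presumably runs a fused kernel and, when use_quant is set, also quantizes
// the result). rms_norm_ref is a hypothetical helper, not part of this API:
inline void rms_norm_ref(const float *x, const float *weight, float *y,
                         std::size_t n, float eps) {
    // rms(x) = sqrt(mean(x^2) + eps); unlike LayerNorm, no mean subtraction.
    float sum_sq = 0.0f;
    for (std::size_t i = 0; i < n; i++) sum_sq += x[i] * x[i];
    const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(n) + eps);
    for (std::size_t i = 0; i < n; i++)
        y[i] = x[i] * inv_rms * weight[i];
}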

// RMSNorm variant fused with activation quantization: instead of returning a
// tensor, forward() writes the quantized activations, their scales, and
// (optionally) per-token activation sums into caller-provided buffers.
class RMSNormGeneral {
    friend class LlamaDecoderLayer;
public:
    RMSNormGeneral(int hidden_size, bool act_sum, float eps, bool use_per_token_quant, Device device)
        : act_sum(act_sum), use_per_token_quant(use_per_token_quant), variance_epsilon(eps)
    {
        // Weight is initialized to ones and is not registered as a parameter;
        // the friend class LlamaDecoderLayer can set it directly.
        this->weight = Tensor::ones({hidden_size}, Tensor::FP32, device);
    }
    void forward(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer) {
        // Dispatch on whether per-token activation sums are also produced.
        if (act_sum) {
            forward_with_act_sum(x, quantized_hidden_states_buffer, quantized_scale_buffer, quantized_sum_buffer);
        } else {
            forward_wo_act_sum(x, quantized_hidden_states_buffer, quantized_scale_buffer, quantized_sum_buffer);
        }
    }

private:
    void forward_with_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer);
    void forward_wo_act_sum(Tensor x, Tensor quantized_hidden_states_buffer, Tensor quantized_scale_buffer, Tensor quantized_sum_buffer);

private:
    const bool act_sum;
    const bool use_per_token_quant;
    const float variance_epsilon;
    Tensor weight;
};
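
// Sketch of one plausible per-token quantization step following the
// normalization above (assumption: symmetric int8 with one scale per token;
// the actual kernel's format and rounding may differ). quantize_token_ref is
// a hypothetical helper, not part of this API:
#include <cstdint>
#include <algorithm>

inline void quantize_token_ref(const float *x, std::int8_t *q, float &scale,
                               float &sum, std::size_t n) {
    // Per-token scale from the max absolute value of the normalized row.
    float amax = 0.0f;
    for (std::size_t i = 0; i < n; i++) amax = std::max(amax, std::fabs(x[i]));
    scale = amax / 127.0f;
    const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
    // Quantize, and accumulate the activation sum that act_sum would expose.
    sum = 0.0f;
    for (std::size_t i = 0; i < n; i++) {
        q[i] = static_cast<std::int8_t>(std::lround(x[i] * inv_scale));
        sum += x[i];
    }
}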