// layernorm.h — normalization layers: LayerNorm, RMSNorm, RMSNormGeneral.
// Contributors (from version-control history): Zhekai Zhang, Muyang Li.
#pragma once

#include "common.h"
#include "Tensor.h"
#include "Module.h"

/**
 * LayerNorm module.
 *
 * Only the interface is declared here; the constructor and forward() are
 * defined out of line, so parameter initialization and the normalization
 * math are not visible in this header.
 */
class LayerNorm : public Module {
public:
    /**
     * @param hidden_size        Size of the normalized dimension (stored in `hidden_size`).
     * @param eps                Epsilon value (stored in `eps`).
     * @param elementwise_affine Presumably toggles the affine weight/bias — confirm
     *                           against the out-of-line definition.
     * @param dtype              Scalar type, presumably for the parameter tensors.
     * @param device             Device, presumably where the parameters are allocated.
     */
    LayerNorm(int hidden_size, float eps, bool elementwise_affine, Tensor::ScalarType dtype, Device device);

    /// Applies the layer to `x` and returns the result (defined out of line).
    Tensor forward(Tensor x);

    const int hidden_size;  ///< Value passed to the constructor; fixed for the object's lifetime.
    const float eps;        ///< Value passed to the constructor; fixed for the object's lifetime.

private:
    Tensor weight;  ///< NOTE(review): presumably the affine scale — confirm in the .cpp.
    Tensor bias;    ///< NOTE(review): presumably the affine shift — confirm in the .cpp.
};

class RMSNorm : public Module {
public:
Muyang Li's avatar
Muyang Li committed
23
24
    RMSNorm(int hidden_size, float eps, bool use_quant, Tensor::ScalarType dtype, Device device)
        : use_quant(use_quant), variance_epsilon(eps) {
Zhekai Zhang's avatar
Zhekai Zhang committed
25
26
27
28
29
30
31
32
33
34
35
36
37
        weight = Tensor::allocate({hidden_size}, dtype, device);
        registerParams(weight, "weight");
    }
    Tensor forward(Tensor x);

public:
    const bool use_quant;
    const float variance_epsilon;
    Tensor weight;
};

class RMSNormGeneral {
    friend class LlamaDecoderLayer;
Muyang Li's avatar
Muyang Li committed
38

Zhekai Zhang's avatar
Zhekai Zhang committed
39
public:
Muyang Li's avatar
Muyang Li committed
40
41
    RMSNormGeneral(int hidden_size, bool act_sum, float eps, bool use_per_token_quant, Device device)
        : act_sum(act_sum), use_per_token_quant(use_per_token_quant), variance_epsilon(eps) {
Zhekai Zhang's avatar
Zhekai Zhang committed
42
43
        this->weight = Tensor::ones({hidden_size}, Tensor::FP32, device);
    }
Muyang Li's avatar
Muyang Li committed
44
45
46
47
    void forward(Tensor x,
                 Tensor quantized_hidden_states_buffer,
                 Tensor quantized_scale_buffer,
                 Tensor quantized_sum_buffer) {
Zhekai Zhang's avatar
Zhekai Zhang committed
48
49
50
51
52
53
54
55
        if (act_sum) {
            forward_with_act_sum(x, quantized_hidden_states_buffer, quantized_scale_buffer, quantized_sum_buffer);
        } else {
            forward_wo_act_sum(x, quantized_hidden_states_buffer, quantized_scale_buffer, quantized_sum_buffer);
        }
    }

private:
Muyang Li's avatar
Muyang Li committed
56
57
58
59
60
61
62
63
    void forward_with_act_sum(Tensor x,
                              Tensor quantized_hidden_states_buffer,
                              Tensor quantized_scale_buffer,
                              Tensor quantized_sum_buffer);
    void forward_wo_act_sum(Tensor x,
                            Tensor quantized_hidden_states_buffer,
                            Tensor quantized_scale_buffer,
                            Tensor quantized_sum_buffer);
Zhekai Zhang's avatar
Zhekai Zhang committed
64
65
66
67
68
69

private:
    const bool act_sum;
    const bool use_per_token_quant;
    const float variance_epsilon;
    Tensor weight;
Muyang Li's avatar
Muyang Li committed
70
};