layernorm.cpp 1.87 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
#include "layernorm.h"
#include "kernels/layernorm_kernels.h"

Muyang Li's avatar
Muyang Li committed
4
5
LayerNorm::LayerNorm(int hidden_size, float eps, bool elementwise_affine, Tensor::ScalarType dtype, Device device)
    : hidden_size(hidden_size), eps(eps) {
    // Per-channel affine parameters are only allocated when elementwise_affine is set.
    // Otherwise weight and bias are not assigned here and keep whatever state their
    // member declarations give them (presumably empty tensors — NOTE(review): confirm
    // in layernorm.h).
    if (elementwise_affine) {
        // Both parameters share the same 1-D shape {hidden_size}, dtype and device.
        const auto make_param = [&]() { return Tensor::allocate({hidden_size}, dtype, device); };
        weight = make_param();
        bias   = make_param();
    }

    // Both names are registered unconditionally, whether or not they were allocated above.
    registerParams(weight, "weight")(bias, "bias");
}

/// Applies layer normalization to `x` via layernorm_general.
/// The result is written into a freshly created tensor with the same shape, dtype and
/// device as `x`. This module's weight, bias and eps are forwarded unchanged. weight/bias
/// may be unallocated if the module was built without elementwise_affine.
/// NOTE(review): confirm the kernel handles that case.
Tensor LayerNorm::forward(Tensor x) {
    Tensor result = Tensor::empty(x.shape, x.scalar_type(), x.device());
    layernorm_general(result, x, weight, bias, eps);
    return result;
}

/// Applies RMS normalization to `x` via rms_norm.
/// The output tensor always follows `x` in shape and device. Its dtype is INT8 when
/// use_quant is set (presumably rms_norm writes quantized values in that mode —
/// NOTE(review): confirm); otherwise it matches `x`.
Tensor RMSNorm::forward(Tensor x) {
    const auto out_dtype = use_quant ? Tensor::INT8 : x.scalar_type();
    Tensor normed        = Tensor::empty(x.shape, out_dtype, x.device());
    rms_norm(normed, x, weight, variance_epsilon, use_quant);
    return normed;
}

Muyang Li's avatar
Muyang Li committed
26
27
28
29
30
31
32
33
34
35
36
/// Runs the fused kernel rms_norm_general_fuse_sum on `x`, using this module's weight,
/// variance_epsilon and use_per_token_quant. All results go into the caller-provided
/// buffers; nothing is allocated or returned here.
void RMSNormGeneral::forward_with_act_sum(Tensor x,
                                          Tensor quantized_hidden_states_buffer,
                                          Tensor quantized_scale_buffer,
                                          Tensor quantized_sum_buffer) {
    // Note: the kernel receives the sum buffer BEFORE the scale buffer, which is the
    // reverse of this method's parameter order.
    rms_norm_general_fuse_sum(/* out    */ quantized_hidden_states_buffer,
                              /* input  */ x,
                              /* weight */ this->weight,
                              /* sum    */ quantized_sum_buffer,
                              /* scale  */ quantized_scale_buffer,
                              /* eps    */ this->variance_epsilon,
                              this->use_per_token_quant);
}

Muyang Li's avatar
Muyang Li committed
39
40
41
42
43
44
/// Runs rms_norm_general on `x` (no activation-sum output), using this module's weight,
/// variance_epsilon and use_per_token_quant. Results go into the caller-provided
/// hidden-states and scale buffers.
void RMSNormGeneral::forward_wo_act_sum(Tensor x,
                                        Tensor quantized_hidden_states_buffer,
                                        Tensor quantized_scale_buffer,
                                        Tensor quantized_sum_buffer) {
    // quantized_sum_buffer is never read or written here. It is presumably kept so the
    // signature mirrors forward_with_act_sum (NOTE(review): confirm with callers).
    (void)quantized_sum_buffer;

    rms_norm_general(quantized_hidden_states_buffer,
                     x,
                     this->weight,
                     quantized_scale_buffer,
                     this->variance_epsilon,
                     this->use_per_token_quant);
}