"vendor/google.golang.org/grpc/codegen.sh" did not exist on "f5f6fd206971a00c8a05e1963f98dd8b494472f6"
layernorm.cpp 3.99 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
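// layernorm.cpp
// CPU RMSNorm kernels (plain and fused with a residual add), vectorized with
// the vec_op helpers from cpu_types.hpp and parallelized over tokens with OpenMP.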
#include "cpu_types.hpp"

namespace {
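// Row-wise RMSNorm over a [num_tokens, hidden_size] matrix:
// out[i] = input[i] * rsqrt(mean(input[i]^2) + epsilon) * weight.
// hidden_size must be a multiple of the SIMD vector width (checked below).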
template <typename scalar_t>
void rms_norm_impl(scalar_t *__restrict__ out,
                   const scalar_t *__restrict__ input,
                   const scalar_t *__restrict__ weight, const float epsilon,
                   const int num_tokens, const int hidden_size) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
  TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);

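// Rows (tokens) are independent, so parallelize across them.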
#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    vec_op::FP32Vec8 variance(0.0);
    auto input_p = input + i * hidden_size;
    auto output_p = out + i * hidden_size;
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t x(input_p + j);
      vec_op::FP32Vec8 fp32_x(x);
      variance = variance + fp32_x * fp32_x;
    }

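    // Reciprocal RMS of this row: 1 / sqrt(mean(x^2) + epsilon).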
    float s_variance =
        1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
    vec_op::FP32Vec8 fp32_s_variance(s_variance);

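    // Second pass: scale each element by the reciprocal RMS and the weight.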
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t x(input_p + j);
      scalar_vec_t w(weight + j);

      vec_op::FP32Vec8 fp32_x(x);
      vec_op::FP32Vec8 fp32_w(w);

      vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w;

      scalar_vec_t out(fp32_out);
      out.save(output_p + j);
    }
  }
}

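// Fused residual add + RMSNorm. The sum input + residual is written back to
// residual, and input is overwritten with RMSNorm(input + residual) * weight.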
template <typename scalar_t>
void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
                             scalar_t *__restrict__ residual,
                             const scalar_t *__restrict__ weight,
                             const float epsilon, const int num_tokens,
                             const int hidden_size) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
  TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    vec_op::FP32Vec8 variance(0.0);
    auto input_p = input + i * hidden_size;
    auto residual_p = residual + i * hidden_size;
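    // First pass: accumulate the sum of squares of x + residual and store the
    // sum back into residual.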
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t x(input_p + j);
      scalar_vec_t res(residual_p + j);
      vec_op::FP32Vec8 fp32_x(x);
      vec_op::FP32Vec8 fp32_res(res);

      fp32_x = fp32_x + fp32_res;
      variance = variance + fp32_x * fp32_x;
      scalar_vec_t out(fp32_x);
      out.save(residual_p + j);
    }

    float s_variance =
        1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
    vec_op::FP32Vec8 fp32_s_variance(s_variance);

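    // Second pass: normalize the summed values (now held in residual) and
    // write the result into input.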
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t w(weight + j);
      scalar_vec_t res(residual_p + j);

      vec_op::FP32Vec8 fp32_w(w);
      vec_op::FP32Vec8 fp32_res(res);

      vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w;

      scalar_vec_t out(fp32_out);
      out.save(input_p + j);
    }
  }
}
} // namespace

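// Dispatches on dtype and applies RMSNorm over a [num_tokens, hidden_size]
// view of `input`, writing the result to `out`.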
void rms_norm(torch::Tensor &out, torch::Tensor &input,
              torch::Tensor &weight, float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
    CPU_KERNEL_GUARD_IN(rms_norm_impl)
    rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
                  weight.data_ptr<scalar_t>(), epsilon, num_tokens,
                  hidden_size);
    CPU_KERNEL_GUARD_OUT(rms_norm_impl)
  });
}

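// In-place fused residual add + RMSNorm; see fused_add_rms_norm_impl for the
// exact update of `input` and `residual`.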
void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual,
                        torch::Tensor &weight, float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "fused_add_rms_norm_impl", [&] {
        CPU_KERNEL_GUARD_IN(fused_add_rms_norm_impl)
        fused_add_rms_norm_impl(
            input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
        CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_impl)
      });
}