// activation.cpp
#include <cstdint>

#include "cpu_types.hpp"

namespace {
4
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&),
5
          bool is_gated>
6
7
void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
                       scalar_t* __restrict__ output) {
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();

  TORCH_CHECK(d % VEC_ELEM_NUM == 0);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    for (int j = 0; j < d; j += VEC_ELEM_NUM) {
      int start = i * d;
      if constexpr (is_gated) {
        start *= 2;
      }

      const scalar_vec_t x(input + start + j);
      const vec_op::FP32Vec8 f32_x(x);
      vec_op::FP32Vec8 f32_ans = func(f32_x);

      if constexpr (is_gated) {
        const scalar_vec_t y(input + start + d + j);
        const vec_op::FP32Vec8 f32_y(y);
        f32_ans = f32_y * f32_ans;
      }

      const scalar_vec_t result(f32_ans);
      result.save(output + i * d + j);
    }
  }
}

37
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) {
38
39
40
41
42
  const vec_op::FP32Vec8 zeros(0.0);
  const vec_op::FP32Vec8 ones(1.0);
  return x / (ones + (zeros - x).exp());
}

43
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
44
45
46
47
48
49
50
51
52
  const vec_op::FP32Vec8 ones(1.0);
  const vec_op::FP32Vec8 w1(0.79788456f);
  const vec_op::FP32Vec8 w2(0.044715f);
  const vec_op::FP32Vec8 w3(0.5);
  const vec_op::FP32Vec8 x3 = x * x * x;
  const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh();
  return w3 * x * (ones + t);
}

53
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
54
55
56
57
58
59
60
61
  const vec_op::FP32Vec8 ones(1.0);
  const vec_op::FP32Vec8 w1(0.79788456f);
  const vec_op::FP32Vec8 w2(0.044715f);
  const vec_op::FP32Vec8 w3(0.5);
  const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh();
  return w3 * x * (ones + t);
}

62
63
64
65
66
67
68
FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) {
  const vec_op::FP32Vec8 zeros(0.0);
  const vec_op::FP32Vec8 ones(1.0);
  const vec_op::FP32Vec8 w1(1.702f);
  return x / (ones + (zeros - w1 * x).exp());
}

69
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
70
71
72
73
74
75
  const vec_op::FP32Vec8 ones(1.0);
  const vec_op::FP32Vec8 w1(M_SQRT1_2);
  const vec_op::FP32Vec8 w2(0.5);
  return x * w2 * (ones + (x * w1).er());
}

76
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
77
78
79
80
81
82
83
84
  const vec_op::FP32Vec8 ones(1.0);
  const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
  const vec_op::FP32Vec8 w2(0.5);
  const vec_op::FP32Vec8 w3(0.044715);
  const vec_op::FP32Vec8 x_3 = x * x * x;
  const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
  return x * w2 * (ones + inner.tanh());
}
85
};  // namespace
86

87
void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
88
89
90
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;

91
92
93
94
95
96
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
    CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
    activation_kernel<scalar_t, silu_act, true>(
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
    CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
  });
97
98
}

99
100
void gelu_and_mul(torch::Tensor& out,    // [..., d]
                  torch::Tensor& input)  // [..., 2 * d]
101
102
103
104
{
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;

105
106
107
108
109
110
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
    CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
    activation_kernel<scalar_t, gelu_act, true>(
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
    CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
  });
111
112
}

113
114
void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
                       torch::Tensor& input)  // [..., 2 * d]
115
116
117
118
119
120
121
122
123
124
125
126
127
128
{
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "gelu_tanh_and_mul_impl", [&] {
        CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl)
        activation_kernel<scalar_t, gelu_tanh_act, true>(
            num_tokens, d, input.data_ptr<scalar_t>(),
            out.data_ptr<scalar_t>());
        CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl)
      });
}

129
void gelu_new(torch::Tensor& out, torch::Tensor& input) {
130
131
132
133
134
135
136
137
138
139
140
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1);

  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] {
    CPU_KERNEL_GUARD_IN(gelu_new_impl)
    activation_kernel<scalar_t, gelu_new_act, false>(
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
    CPU_KERNEL_GUARD_OUT(gelu_new_impl)
  });
}

141
void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
142
143
144
145
146
147
148
149
150
151
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1);

  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] {
    CPU_KERNEL_GUARD_IN(gelu_fast_impl)
    activation_kernel<scalar_t, gelu_fast_act, false>(
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
    CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
  });
}
152
153
154
155
156
157
158
159
160
161
162
163

void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1);

  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] {
    CPU_KERNEL_GUARD_IN(gelu_quick_impl)
    activation_kernel<scalar_t, gelu_quick_act, false>(
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
    CPU_KERNEL_GUARD_OUT(gelu_quick_impl)
  });
}