// pos_encoding.cpp

#include "cpu_types.hpp"

namespace {
template <typename scalar_t>
void rotary_embedding_impl(
7
8
9
10
11
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,           /// [batch_size, seq_len, num_heads,
                                   /// head_size] or [num_tokens, num_heads,
                                   /// head_size]
12
13
    scalar_t* __restrict__ key,  // nullptr (optional) or
                                 // [batch_size, seq_len, num_kv_heads,
14
15
16
17
                                 // head_size] or [num_tokens, num_kv_heads,
                                 // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
18
19
20
21
22
23
24
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size,
    const int num_tokens) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();

  const int embed_dim = rot_dim / 2;
25
26
  bool flag = (embed_dim % VEC_ELEM_NUM == 0);
  const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;
27

28
29
30
31
32
33
34
  auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr,
                          scalar_t* qk) {
    int j = 0;
    for (; j < loop_upper; j += VEC_ELEM_NUM) {
      const int rot_offset = j;
      const int x_index = rot_offset;
      const int y_index = embed_dim + rot_offset;
35

36
37
      const int64_t out_x = token_head + x_index;
      const int64_t out_y = token_head + y_index;
38

39
40
      const scalar_vec_t cos(cache_ptr + x_index);
      const scalar_vec_t sin(cache_ptr + y_index);
41

42
43
      const scalar_vec_t q_x(qk + out_x);
      const scalar_vec_t q_y(qk + out_y);
44

45
46
      vec_op::FP32Vec8 fp32_cos(cos);
      vec_op::FP32Vec8 fp32_sin(sin);
47

48
49
      vec_op::FP32Vec8 fp32_q_x(q_x);
      vec_op::FP32Vec8 fp32_q_y(q_y);
50

51
52
      auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
      scalar_vec_t(out1).save(qk + out_x);
53

54
55
      auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
      scalar_vec_t(out2).save(qk + out_y);
56
    }
57
58
59
60
    if (!flag) {
      for (; j < embed_dim; ++j) {
        const int x_index = j;
        const int y_index = embed_dim + j;
61
62
63
64

        const int64_t out_x = token_head + x_index;
        const int64_t out_y = token_head + y_index;

65
66
        const float fp32_cos = cache_ptr[x_index];
        const float fp32_sin = cache_ptr[y_index];
67

68
69
        const float fp32_q_x = qk[out_x];
        const float fp32_q_y = qk[out_y];
70

71
72
73
74
75
        qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
        qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
      }
    }
  };
76

77
78
79
80
#pragma omp parallel for
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    int64_t pos = positions[token_idx];
    const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
81

82
83
84
85
86
87
88
    for (int i = 0; i < num_heads; ++i) {
      const int head_idx = i;
      const int64_t token_head =
          token_idx * query_stride + head_idx * head_size;
      compute_loop(token_head, cache_ptr, query);
    }

89
90
91
92
93
94
95
    if (key != nullptr) {
      for (int i = 0; i < num_kv_heads; ++i) {
        const int head_idx = i;
        const int64_t token_head =
            token_idx * key_stride + head_idx * head_size;
        compute_loop(token_head, cache_ptr, key);
      }
96
97
98
99
100
101
    }
  }
}

/// Rotates every head of one tensor (`data` is either the query or the key
/// buffer) in place using the interleaved pairing: within a head, elements
/// `2 * j` and `2 * j + 1` form one (x, y) pair for j in [0, rot_dim / 2).
///
/// `token_stride` is the element distance between consecutive tokens in
/// `data`, and `head_count` is the number of heads stored per token.
template <typename scalar_t>
void rotary_embedding_gptj_rotate_heads(
    const int64_t* __restrict__ positions, scalar_t* __restrict__ data,
    const scalar_t* __restrict__ cos_sin_cache, const int rot_dim,
    const int64_t token_stride, const int head_count, const int head_size,
    const int num_tokens) {
  const int embed_dim = rot_dim / 2;

#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int head_idx = 0; head_idx < head_count; ++head_idx) {
      // The position lookup stays inside the inner loop on purpose:
      // collapse(2) requires the two loops to be perfectly nested.
      const int64_t pos = positions[token_idx];
      // Cache row layout: first embed_dim entries are cos, next are sin.
      const scalar_t* cos_cache_ptr = cos_sin_cache + pos * rot_dim;
      const scalar_t* sin_cache_ptr = cos_cache_ptr + embed_dim;

      const int64_t token_head =
          token_idx * token_stride + head_idx * head_size;
      scalar_t* head = data + token_head;

      for (int j = 0; j < embed_dim; ++j) {
        const int x_index = 2 * j;
        const int y_index = 2 * j + 1;

        // Read everything as float so the rotation is computed in fp32.
        const float cos_val = cos_cache_ptr[j];
        const float sin_val = sin_cache_ptr[j];

        const float x = head[x_index];
        const float y = head[y_index];

        head[x_index] = x * cos_val - y * sin_val;
        head[y_index] = y * cos_val + x * sin_val;
      }
    }
  }
}

/// Applies rotary position embedding in place using the interleaved
/// (adjacent-element) pairing. This is the kernel `rotary_embedding` selects
/// when `is_neox` is false.
///
/// For every token, the cos/sin row is looked up at
/// `cos_sin_cache + positions[token] * rot_dim` and each pair (x, y) is
/// replaced by (x * cos - y * sin, y * cos + x * sin). Only the first
/// `rot_dim` elements of each head are touched. `key` may be nullptr, in
/// which case only `query` is rotated.
template <typename scalar_t>
void rotary_embedding_gptj_impl(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // nullptr (optional) or
                                   // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size,
    const int num_tokens) {
  rotary_embedding_gptj_rotate_heads(positions, query, cos_sin_cache, rot_dim,
                                     query_stride, num_heads, head_size,
                                     num_tokens);

  if (key == nullptr) {
    return;
  }

  rotary_embedding_gptj_rotate_heads(positions, key, cos_sin_cache, rot_dim,
                                     key_stride, num_kv_heads, head_size,
                                     num_tokens);
}
};  // namespace

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
180
                      std::optional<torch::Tensor> key, int64_t head_size,
181
                      torch::Tensor& cos_sin_cache, bool is_neox) {
182
  int num_tokens = positions.numel();
183
184
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
185
186
  int num_kv_heads = key.has_value() ? key->size(-1) / head_size : num_heads;
  int64_t key_stride = key.has_value() ? key->stride(-2) : 0;
187
188
189
190
191
192
193
194
  int64_t query_stride = query.stride(-2);

  VLLM_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "rotary_embedding_impl", [&] {
        CPU_KERNEL_GUARD_IN(rotary_embedding_impl)
        if (is_neox) {
          rotary_embedding_impl(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
195
196
197
              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
              cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size, num_tokens);
198
199
200
        } else {
          rotary_embedding_gptj_impl(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
201
202
203
              key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
              cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size, num_tokens);
204
205
206
207
208
        }

        CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)
      });
}