#include "cpu_types.hpp"

#include <float.h>
#include <omp.h>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <vector>

namespace {
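// Maps each scalar type to the SIMD vector types the kernel uses:
//   qk_load_vec_type - type used to load Q and K from memory,
//   qk_vec_type      - element type fed to the Q @ K^T FMA
//                      (the accumulator itself is always FP32),
//   v_load_vec_type  - type used to load V from the KV cache.
// The primary template is deliberately unusable (void aliases); only the
// specializations below are valid.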
template <typename scalar_t>
struct KernelVecType {
  using qk_load_vec_type = void;
  using qk_vec_type = void;
  using v_load_vec_type = void;
};

template <>
struct KernelVecType<float> {
  using qk_load_vec_type = vec_op::FP32Vec16;
  using qk_vec_type = vec_op::FP32Vec16;
  using v_load_vec_type = vec_op::FP32Vec16;
};

template <>
struct KernelVecType<c10::Half> {
#if defined(__powerpc64__)
  // Power specific vector types
  using qk_load_vec_type = vec_op::FP32Vec16;
  using qk_vec_type = vec_op::FP32Vec16;
  using v_load_vec_type = vec_op::FP32Vec16;
#else
  // Fallback for other architectures, including x86
  using qk_load_vec_type = vec_op::FP16Vec16;
  using qk_vec_type = vec_op::FP32Vec16;
  using v_load_vec_type = vec_op::FP16Vec16;
#endif
};

#ifdef __AVX512BF16__
template <>
struct KernelVecType<c10::BFloat16> {
  using qk_load_vec_type = vec_op::BF16Vec32;
  using qk_vec_type = vec_op::BF16Vec32;
  using v_load_vec_type = vec_op::BF16Vec16;
};
#else
template <>
struct KernelVecType<c10::BFloat16> {
  using qk_load_vec_type = vec_op::BF16Vec16;
  using qk_vec_type = vec_op::FP32Vec16;
  using v_load_vec_type = vec_op::BF16Vec16;
};
#endif

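// Computes attention for HEAD_UNROLL query heads against a single KV-cache
// block of up to BLOCK_SIZE tokens (num_tokens of them valid), then folds
// the result into the caller's running (acc_out, acc_lse) state via the
// log-sum-exp merge documented below.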
template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, int HEAD_UNROLL,
          typename qk_vec_type>
void mla_decode_block_head(
    const qk_vec_type* __restrict__ q_vecs,          // [HEAD_UNROLL, head_dim]
    const qk_vec_type* __restrict__ k_vecs,          // [block_size, head_dim]
    const vec_op::FP32Vec16* __restrict__ v_vecs_f32,  // [block_size, v_head_dim]
    float* __restrict__ acc_out,  // [HEAD_UNROLL, v_head_dim]
    float* __restrict__ acc_lse,  // [HEAD_UNROLL]
    const float scale, const int num_tokens) {
  using f32_vec_type = vec_op::FP32Vec16;
  constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
  constexpr int V_NUM_ELEM = f32_vec_type::VEC_ELEM_NUM;

  float logits[BLOCK_SIZE][HEAD_UNROLL] = {};  // initialize to zeros
  float max_val[HEAD_UNROLL];
  std::fill(max_val, max_val + HEAD_UNROLL, -FLT_MAX);

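  // Pass 1: accumulate Q . K partial products for every (token, head) pair;
  // acc_vec starts at zero because FP32Vec16 default-constructs to zero. The
  // loop that follows reduces each accumulator to a scalar logit (times
  // scale) while tracking the per-head running max.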
  f32_vec_type acc_vec[BLOCK_SIZE][HEAD_UNROLL];
  for (int i = 0; i < HEAD_DIM; i += QK_NUM_ELEM) {
    // load to registers
    qk_vec_type q_vec[HEAD_UNROLL];

#pragma unroll
    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
      q_vec[unroll] =
          qk_vec_type{q_vecs[(i + unroll * HEAD_DIM) / QK_NUM_ELEM]};

    for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
      qk_vec_type k_vec(k_vecs[(block_offset * HEAD_DIM + i) / QK_NUM_ELEM]);

#pragma unroll
      for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
        vec_op::fma(acc_vec[block_offset][unroll], q_vec[unroll], k_vec);
    }
  }

  for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
#pragma unroll
    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
      const float acc = acc_vec[block_offset][unroll].reduce_sum() * scale;
      logits[block_offset][unroll] = acc;
      max_val[unroll] = std::max(max_val[unroll], acc);
    }
  }

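  // Pass 2: numerically stable softmax numerators; subtracting the per-head
  // max before exp() keeps the exponentials in range. After this loop,
  // logits[][] holds exp(logit - max) and sum_exp the per-head normalizer.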
  float sum_exp[HEAD_UNROLL] = {};
  for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
#pragma unroll
    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
      const float val =
          std::exp(logits[block_offset][unroll] - max_val[unroll]);
      logits[block_offset][unroll] = val;
      sum_exp[unroll] += val;
    }
  }

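  // Pass 3: this_out = softmax(logits) @ V, forming each probability on the
  // fly as logits / sum_exp. Note the HEAD_DIM (not V_HEAD_DIM) row stride:
  // V shares its rows with K in the MLA KV cache.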
  f32_vec_type this_out[V_HEAD_DIM / V_NUM_ELEM][HEAD_UNROLL];

  for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
    // load to registers
    f32_vec_type scale_[HEAD_UNROLL];

#pragma unroll
    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
      scale_[unroll] =
          f32_vec_type{logits[block_offset][unroll] / sum_exp[unroll]};

    for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
      f32_vec_type v_vec(
          v_vecs_f32[(block_offset * HEAD_DIM + i) / V_NUM_ELEM]);

#pragma unroll
      for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
        vec_op::fma(this_out[i / V_NUM_ELEM][unroll], v_vec, scale_[unroll]);
    }
  }

  // merge attention state
  // section 2.2 in https://arxiv.org/pdf/2501.01005
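  // given two partial attention states (o1, lse1) and (o2, lse2) computed
  // over disjoint token ranges, the exact combination is
  //   lse = log(exp(lse1) + exp(lse2))
  //   o   = (exp(lse1) * o1 + exp(lse2) * o2) / (exp(lse1) + exp(lse2))
  // evaluated below with the usual max-subtraction for stability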
  f32_vec_type prev_scale[HEAD_UNROLL];
  f32_vec_type curr_scale[HEAD_UNROLL];

#pragma unroll
  for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
    const float prev_lse = acc_lse[unroll];
    const float curr_lse = std::log(sum_exp[unroll]) +
                           max_val[unroll];  // add back max_val to get true lse
    // softmax trick
    const float max_lse = std::max(prev_lse, curr_lse);
    const float prev_sum_exp = std::exp(prev_lse - max_lse);
    const float curr_sum_exp = std::exp(curr_lse - max_lse);

    const float new_sum_exp = prev_sum_exp + curr_sum_exp;
    acc_lse[unroll] = std::log(new_sum_exp) + max_lse;

    prev_scale[unroll] = f32_vec_type{prev_sum_exp / new_sum_exp};
    curr_scale[unroll] = f32_vec_type{curr_sum_exp / new_sum_exp};
  }

  for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
#pragma unroll
    for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
      f32_vec_type o_vec(acc_out + i + V_HEAD_DIM * unroll);
      o_vec = o_vec * prev_scale[unroll] +
              this_out[i / V_NUM_ELEM][unroll] * curr_scale[unroll];
      o_vec.save(acc_out + i + V_HEAD_DIM * unroll);
    }
  }

}

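// Processes one KV-cache block for all query heads. The block is optionally
// converted to FP32 once, so the conversion cost is amortized across heads,
// and heads are then processed HEAD_UNROLL at a time.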
template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE,
          typename qk_vec_type>
void mla_decode_block(
    const qk_vec_type* __restrict__ q_vecs,  // [num_heads, head_dim]
    const scalar_t* __restrict__ kv_cache,   // [block_size, head_dim]
    float* __restrict__ acc_out,             // [num_heads, v_head_dim]
    float* __restrict__ acc_lse,             // [num_heads]
    const int num_heads, const float scale, const int num_tokens) {
  using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
  static_assert(
      std::is_same<qk_vec_type,
                   typename KernelVecType<scalar_t>::qk_vec_type>::value);
  using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
  using f32_vec_type = vec_op::FP32Vec16;
  static_assert(qk_load_vec_type::VEC_ELEM_NUM == qk_vec_type::VEC_ELEM_NUM);
  static_assert(v_load_vec_type::VEC_ELEM_NUM == f32_vec_type::VEC_ELEM_NUM);
  constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
  constexpr int V_NUM_ELEM = v_load_vec_type::VEC_ELEM_NUM;

  const qk_vec_type* k_vecs;
  const f32_vec_type* v_vecs_f32;
  float* kv_cache_f32 = nullptr;

  if constexpr (!std::is_same<scalar_t, float>::value) {
    // convert this KV cache block to FP32 once so it can be reused across
    // query heads and the attn @ V computation, since FP16/BF16 -> FP32
    // conversion is expensive.
    // TODO: move this allocation outside the function to reuse it across
    // iterations.
    const int nbytes = BLOCK_SIZE * HEAD_DIM * sizeof(float);
    kv_cache_f32 = static_cast<float*>(std::aligned_alloc(64, nbytes));

    for (int block_offset = 0; block_offset < num_tokens; ++block_offset)
      for (int i = 0; i < HEAD_DIM; i += V_NUM_ELEM) {
        v_load_vec_type kv_load_vec(kv_cache + block_offset * HEAD_DIM + i);
        f32_vec_type kv_vec_f32(kv_load_vec);
        kv_vec_f32.save(kv_cache_f32 + block_offset * HEAD_DIM + i);
      }

    if constexpr (std::is_same<qk_load_vec_type, qk_vec_type>::value) {
      // for AVX512_BF16, Q @ K.T uses BF16 for K (no conversion)
      // NOTE: in this case, we only need to convert the V section to FP32.
      // But for simplicity, we will convert the whole KV block to FP32.
      k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
    } else {
      k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache_f32);
    }

    // attn @ V always uses FP32 for V, since the attn weights are FP32.
    v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache_f32);

  } else {
    // KV cache is already FP32; no conversion needed.
    k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
    v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache);
  }

  // process HEAD_UNROLL heads at a time to improve ILP and to reuse each
  // loaded K/V vector across heads.
  constexpr int HEAD_UNROLL = 2;
  for (int iter = 0; iter < num_heads / HEAD_UNROLL; ++iter) {
    mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, HEAD_UNROLL>(
        q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);

    q_vecs += HEAD_UNROLL * HEAD_DIM / QK_NUM_ELEM;
    acc_out += HEAD_UNROLL * V_HEAD_DIM;
    acc_lse += HEAD_UNROLL;
  }

  // take care of the remaining heads
  for (int iter = 0; iter < num_heads % HEAD_UNROLL; ++iter) {
    mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, 1>(
        q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);

    q_vecs += HEAD_DIM / QK_NUM_ELEM;
    acc_out += V_HEAD_DIM;
    acc_lse += 1;
  }

  if (kv_cache_f32 != nullptr) {
    std::free(kv_cache_f32);
  }
}
}  // namespace

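// Flash-decoding style driver: the KV blocks of each sequence are split
// across OpenMP threads, every thread accumulates a private partial
// attention state (acc_out, acc_lse), and the per-thread states are merged
// once per sequence.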
template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE>
void mla_decode_kvcache_cpu_impl(
    scalar_t* __restrict__ out,             // [num_seqs, num_heads, v_head_dim]
    const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_dim]
    const scalar_t* __restrict__ kv_cache,  // [num_blocks, block_size,
                                            // head_dim]
    const int num_heads, const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ seq_lens,      // [num_seqs]
    const int max_num_blocks_per_seq, const int o_stride, const int q_stride,
    const int kv_stride, const int num_seqs) {
  using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
  using qk_vec_type = typename KernelVecType<scalar_t>::qk_vec_type;
  constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;

  // shared across threads
  const int max_threads = omp_get_max_threads();
  const int acc_out_nbytes =
      max_threads * num_heads * V_HEAD_DIM * sizeof(float);
  float* acc_out = static_cast<float*>(std::aligned_alloc(64, acc_out_nbytes));
  std::vector<float> acc_lse(max_threads * num_heads);

  // allocate memory to pre-convert query to FP32 later
  float* q_f32 = nullptr;
  constexpr bool PRE_CONVERT_QUERY =
      !std::is_same<scalar_t, float>::value &&
      std::is_same<qk_vec_type, vec_op::FP32Vec16>::value;
  if constexpr (PRE_CONVERT_QUERY) {
    const int q_f32_nbytes = num_heads * HEAD_DIM * sizeof(float);
    q_f32 = static_cast<float*>(std::aligned_alloc(64, q_f32_nbytes));
  }

#pragma omp parallel
  {
    const int num_threads = omp_get_num_threads();
    const int thread_id = omp_get_thread_num();
    float* __restrict__ acc_out_thread =
        acc_out + thread_id * num_heads * V_HEAD_DIM;
    float* __restrict__ acc_lse_thread = acc_lse.data() + thread_id * num_heads;

    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
      // reset accumulator
      std::fill(acc_out_thread, acc_out_thread + num_heads * V_HEAD_DIM, 0.0f);
      std::fill(acc_lse_thread, acc_lse_thread + num_heads, -FLT_MAX);

      const int seq_len = seq_lens[seq_idx];
      const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
      const int last_block_size = seq_len - (block_num - 1) * BLOCK_SIZE;

      const qk_vec_type* q_vecs;
      if constexpr (PRE_CONVERT_QUERY) {
// pre-convert query to FP32 since FP16/BF16->FP32 is slow.
#pragma omp for
        for (int i = 0; i < num_heads * HEAD_DIM; i += QK_NUM_ELEM) {
          qk_load_vec_type q_load_vec(q + seq_idx * q_stride + i);
          qk_vec_type q_vec(q_load_vec);
          q_vec.save(q_f32 + i);
        }
        q_vecs = reinterpret_cast<const qk_vec_type*>(q_f32);
      } else {
        q_vecs = reinterpret_cast<const qk_vec_type*>(q + seq_idx * q_stride);
      }

#pragma omp for
      for (int block_idx = 0; block_idx < block_num; ++block_idx) {
        const int physical_block_idx =
            block_tables[seq_idx * max_num_blocks_per_seq + block_idx];
        const int num_tokens =
            block_idx < block_num - 1 ? BLOCK_SIZE : last_block_size;

        mla_decode_block<scalar_t, HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE>(
            q_vecs, kv_cache + physical_block_idx * kv_stride, acc_out_thread,
            acc_lse_thread, num_heads, scale, num_tokens);
      }

// merge attention states across threads
// section 2.2 in https://arxiv.org/pdf/2501.01005
// each thread is responsible for 1 head
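// with per-thread normalized outputs o_t and log-sum-exps l_t, the merged
// output is sum_t(exp(l_t - m) * o_t) / sum_t(exp(l_t - m)), m = max_t(l_t);
// threads that processed no block for this sequence still hold
// l_t = -FLT_MAX and therefore contribute zero weight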
#pragma omp for
      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
        float* acc_lse_head = acc_lse.data() + head_idx;
        float* acc_out_head = acc_out + head_idx * V_HEAD_DIM;

        float max_val = -FLT_MAX;
        for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
          max_val = std::max(max_val, acc_lse_head[thread_id_ * num_heads]);
        }

        float sum_exp = 0.0f;
        for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
          float val = std::exp(acc_lse_head[thread_id_ * num_heads] - max_val);
          acc_lse_head[thread_id_ * num_heads] = val;
          sum_exp += val;
        }

        float inv_sum = 1.0f / sum_exp;
        float out_head[V_HEAD_DIM] = {};
        for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
          float scale_ = acc_lse_head[thread_id_ * num_heads] * inv_sum;
          for (int i = 0; i < V_HEAD_DIM; ++i) {
            out_head[i] +=
                acc_out_head[thread_id_ * num_heads * V_HEAD_DIM + i] * scale_;
          }
        }

        for (int i = 0; i < V_HEAD_DIM; ++i) {
          vec_op::storeFP32(out_head[i], out + seq_idx * o_stride +
                                             head_idx * V_HEAD_DIM + i);
        }
      }
    }
  }
  if constexpr (PRE_CONVERT_QUERY) {
    std::free(q_f32);
  }
  std::free(acc_out);
}

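// PyTorch entry point. Expected shapes:
//   out          [num_seqs, num_heads, v_head_dim]
//   query        [num_seqs, num_heads, head_dim]
//   kv_cache     [num_blocks, block_size, head_dim]
//   block_tables [num_seqs, max_num_blocks_per_seq], seq_lens [num_seqs]
// Only head_dim = 576, v_head_dim = 512, block_size = 16 is specialized.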
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& kv_cache, double scale,
                        torch::Tensor& block_tables, torch::Tensor& seq_lens) {
  const int num_seqs = query.size(0);
  const int num_heads = query.size(1);
  const int head_dim = query.size(2);
  const int block_size = kv_cache.size(1);
  const int v_head_dim = out.size(2);

  const int max_num_blocks_per_seq = block_tables.size(1);
  const int o_stride = out.stride(0);
  const int q_stride = query.stride(0);
  const int kv_stride = kv_cache.stride(0);

  VLLM_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "mla_decode_kvcache_cpu_impl", [&] {
        CPU_KERNEL_GUARD_IN(mla_decode_kvcache_cpu_impl)
        if (head_dim == 576 && v_head_dim == 512 && block_size == 16)
          mla_decode_kvcache_cpu_impl<scalar_t, 576, 512, 16>(
              out.data_ptr<scalar_t>(), query.data_ptr<scalar_t>(),
              kv_cache.data_ptr<scalar_t>(), num_heads, scale,
              block_tables.data_ptr<int>(), seq_lens.data_ptr<int>(),
              max_num_blocks_per_seq, o_stride, q_stride, kv_stride, num_seqs);
        else
          TORCH_CHECK(false, "Unsupported MLA decode configuration: head_dim=",
                      head_dim, ", v_head_dim=", v_head_dim,
                      ", block_size=", block_size);
        CPU_KERNEL_GUARD_OUT(mla_decode_kvcache_cpu_impl)
      });
}