transfer.cu 45.6 KB
Newer Older
1
2
3
4
5
6
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/irange.h>

#include <cstdint>

7
#ifndef USE_ROCM
maxiao1's avatar
maxiao1 committed
8
#define WARP_SIZE 64
9
#include "pytorch_extension_utils.h"
10
11
12
13
#else
#include "pytorch_extension_utils_rocm.h"
#include "utils.h"  // WARP_SIZE
#endif
14
15
16

// Cooperatively copies `item_size_bytes` bytes from `src_addr` to `dst_addr` using the WARP_SIZE lanes of one
// (logical) warp. The copy is done in 8-byte words: lane `lane_id` moves words lane_id, lane_id + WARP_SIZE,
// lane_id + 2 * WARP_SIZE, ... Any bytes past the last full 8-byte word are NOT copied, so callers are expected
// to pass a size that is a multiple of 8. Both addresses must be 8-byte aligned for the 64-bit accesses.
//
// CUDA path: loads go through the non-coherent path (ld.global.nc) and stores use .cg (cache in L2, not L1).
// ROCm path: non-temporal load/store builtins.
__device__ __forceinline__ void
transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) {
  const uint64_t* __restrict__ from = static_cast<const uint64_t*>(src_addr);
  uint64_t* __restrict__ to = static_cast<uint64_t*>(dst_addr);
  const int num_words = item_size_bytes / sizeof(uint64_t);

#pragma unroll
  for (int word_idx = lane_id; word_idx < num_words; word_idx += WARP_SIZE) {
#ifndef USE_ROCM
    uint64_t word;
    asm volatile("ld.global.nc.b64 %0,[%1];" : "=l"(word) : "l"(from + word_idx) : "memory");
    asm volatile("st.global.cg.b64 [%0],%1;" ::"l"(to + word_idx), "l"(word) : "memory");
#else
    const uint64_t word = __builtin_nontemporal_load(from + word_idx);
    __builtin_nontemporal_store(word, to + word_idx);
#endif
  }
}

35
36
37
38
39
40
41
42
// Address of one slot in a layer-first buffer addressed from a single base pointer.
//   base            - buffer start
//   layer_dim       - stride between consecutive layers
//   item_size_bytes - stride between consecutive slots inside a layer
// Strides are in units of T (bytes for the char instantiations used in this file). The layer-table argument
// exists only so all offset functions share one signature; it is ignored here.
template <typename T>
__device__ __forceinline__ T* get_global_offset_lf(
    T* base,
    const uintptr_t* __restrict__ /*unused*/,
    int64_t layer_id,
    int64_t layer_dim,
    int64_t page_id,
    int64_t item_size_bytes) {
  T* const layer_base = base + layer_id * layer_dim;
  return layer_base + page_id * item_size_bytes;
}

47
48
49
50
51
52
53
54
// Address of one (slot, layer) item in a page-first buffer: all layers of a slot are stored together.
//   page_dim        - stride between consecutive slots
//   item_size_bytes - stride between consecutive layers inside a slot
// Strides are in units of T (bytes for the char instantiations used in this file). The layer-table argument
// is ignored; it only keeps the signature uniform with the other offset functions.
template <typename T>
__device__ __forceinline__ T* get_global_offset_pf(
    T* base,
    const uintptr_t* __restrict__ /*unused*/,
    int64_t layer_id,
    int64_t page_dim,
    int64_t page_id,
    int64_t item_size_bytes) {
  T* const slot_base = base + page_id * page_dim;
  return slot_base + layer_id * item_size_bytes;
}

// Layer-first offset for layers that are not contiguous in memory: each layer's base address is read from
// `layer_base_tbl[layer_id]` instead of being derived from one base pointer plus a layer stride. The base
// pointer and layer-stride arguments are ignored; they keep the signature uniform with the other offset functions.
template <typename T>
__device__ __forceinline__ T* get_global_offset_lf_tbl(
    T* /*unused*/,
    const uintptr_t* __restrict__ layer_base_tbl,
    int64_t layer_id,
    int64_t /*unused*/,
    int64_t page_id,
    int64_t item_size_bytes) {
  T* const layer_base = reinterpret_cast<T*>(layer_base_tbl[layer_id]);
  return layer_base + page_id * item_size_bytes;
}

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
// Layer-first offset of head `head_id` inside slot `page_id` of layer `layer_id`. An item of `item_size_bytes`
// is treated as `head_num` equally sized head chunks laid out back to back. The trailing page_size argument is
// ignored; it only keeps the signature uniform with get_global_offset_ph.
template <typename T>
__device__ __forceinline__ T* get_global_offset_per_head_lf(
    T* base,
    const uintptr_t* __restrict__ /*unused*/,
    int64_t layer_id,
    int64_t layer_dim,
    int64_t page_id,
    int64_t item_size_bytes,
    int64_t head_id,
    int64_t head_num,
    int64_t /*unused*/) {
  T* const item_start = base + layer_id * layer_dim + page_id * item_size_bytes;
  const int64_t head_chunk = item_size_bytes / head_num;
  return item_start + head_chunk * head_id;
}

// Per-head variant of get_global_offset_lf_tbl: the layer base comes from `layer_base_tbl[layer_id]`, and
// head `head_id` starts `item_size_bytes / head_num * head_id` units into the slot. The base pointer, layer
// stride and page_size arguments are ignored (uniform signature).
template <typename T>
__device__ __forceinline__ T* get_global_offset_per_head_lf_tbl(
    T* /*unused*/,
    const uintptr_t* __restrict__ layer_base_tbl,
    int64_t layer_id,
    int64_t /*unused*/,
    int64_t page_id,
    int64_t item_size_bytes,
    int64_t head_id,
    int64_t head_num,
    int64_t /*unused*/) {
  T* const item_start = reinterpret_cast<T*>(layer_base_tbl[layer_id]) + page_id * item_size_bytes;
  const int64_t head_chunk = item_size_bytes / head_num;
  return item_start + head_chunk * head_id;
}

// Offset into a "page head" buffer with logical layout [page_num, head_num, page_size, layer_num, head_dim].
//   page_id         - token slot index; split into (page number, position in page) with page_size
//   page_dim        - bytes of one token across all heads and layers
//   item_size_bytes - bytes of one token for one layer across all heads (so one head of one layer is
//                     item_size_bytes / head_num)
// The layer-table argument is ignored (uniform signature).
template <typename T>
__device__ __forceinline__ T* get_global_offset_ph(
    T* base,
    const uintptr_t* __restrict__ /*unused*/,
    int64_t layer_id,
    int64_t page_dim,
    int64_t page_id,
    int64_t item_size_bytes,
    int64_t head_id,
    int64_t head_num,
    int64_t page_size) {
  const int64_t page_term = page_id / page_size * page_size * page_dim;     // page_num dimension
  const int64_t head_term = page_dim / head_num * head_id * page_size;      // head_num dimension
  const int64_t token_term = page_id % page_size * page_dim / head_num;     // page_size dimension
  const int64_t layer_term = layer_id * item_size_bytes / head_num;         // layer_num dimension
  return base + page_term + head_term + token_term + layer_term;
}

// Per-head KV copy kernel used when one side uses the page-head layout.
// Each logical warp (WARP_SIZE consecutive global threads) handles items
// [warp * items_per_warp, (warp + 1) * items_per_warp), clipped to num_items. For every item, for every layer in
// [start_layer_id, start_layer_id + num_layers_to_process), and for every head, it copies the K chunk and then the
// V chunk of item_size_bytes / head_num bytes. Addresses come from SrcOffsetFn / DstOffsetFn, which receive the
// item's source / destination index from src_indices / dst_indices.
template <auto SrcOffsetFn, auto DstOffsetFn>
__global__ void transfer_page_head_kernel_impl(
    const void* __restrict__ src_k,
    void* __restrict__ dst_k,
    const void* __restrict__ src_v,
    void* __restrict__ dst_v,
    const int64_t* __restrict__ src_indices,
    const int64_t* __restrict__ dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t num_items,
    int64_t items_per_warp,
    int64_t item_size_bytes,
    int64_t src_layout_dim,
    int64_t dst_layout_dim,
    const uintptr_t* __restrict__ src_k_layer_tbl,
    const uintptr_t* __restrict__ dst_k_layer_tbl,
    const uintptr_t* __restrict__ src_v_layer_tbl,
    const uintptr_t* __restrict__ dst_v_layer_tbl,
    const int64_t page_size,
    const int64_t head_num) {
  const int32_t global_thread = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t lane = global_thread % WARP_SIZE;
  const int32_t warp = global_thread / WARP_SIZE;
  const int64_t head_bytes = item_size_bytes / head_num;
  const int64_t end_layer = start_layer_id + num_layers_to_process;
  const int64_t first_item = warp * items_per_warp;
  const int64_t past_last_item = first_item + items_per_warp;

  for (int64_t item = first_item; item < past_last_item && item < num_items; ++item) {
    const int64_t src_slot = src_indices[item];
    const int64_t dst_slot = dst_indices[item];

    for (int64_t layer = start_layer_id; layer < end_layer; ++layer) {
      // In the page-head layout a token's heads are not adjacent, so every head is copied on its own.
      for (int64_t head = 0; head < head_num; ++head) {
        const char* k_from = SrcOffsetFn(
            static_cast<const char*>(src_k),
            src_k_layer_tbl,
            layer,
            src_layout_dim,
            src_slot,
            item_size_bytes,
            head,
            head_num,
            page_size);
        char* k_to = DstOffsetFn(
            static_cast<char*>(dst_k),
            dst_k_layer_tbl,
            layer,
            dst_layout_dim,
            dst_slot,
            item_size_bytes,
            head,
            head_num,
            page_size);
        transfer_item_warp(lane, k_from, k_to, head_bytes);

        const char* v_from = SrcOffsetFn(
            static_cast<const char*>(src_v),
            src_v_layer_tbl,
            layer,
            src_layout_dim,
            src_slot,
            item_size_bytes,
            head,
            head_num,
            page_size);
        char* v_to = DstOffsetFn(
            static_cast<char*>(dst_v),
            dst_v_layer_tbl,
            layer,
            dst_layout_dim,
            dst_slot,
            item_size_bytes,
            head,
            head_num,
            page_size);
        transfer_item_warp(lane, v_from, v_to, head_bytes);
      }
    }
  }
}

205
206
207
208
209
210
211
212
213
214
215
216
217
218
// Whole-item KV copy kernel.
// Each logical warp (WARP_SIZE consecutive global threads) handles items
// [warp * items_per_warp, (warp + 1) * items_per_warp), clipped to num_items. For each item and each layer in
// [start_layer_id, start_layer_id + num_layers_to_process) it copies item_size_bytes of K. When IsMLA is false
// it also copies the matching V item. Addresses come from SrcOffsetFn / DstOffsetFn.
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
__global__ void transfer_kernel_impl(
    const void* __restrict__ src_k,
    void* __restrict__ dst_k,
    const void* __restrict__ src_v,
    void* __restrict__ dst_v,
    const int64_t* __restrict__ src_indices,
    const int64_t* __restrict__ dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t num_items,
    int64_t items_per_warp,
    int64_t item_size_bytes,
    int64_t src_layout_dim,
    int64_t dst_layout_dim,
    const uintptr_t* __restrict__ src_k_layer_tbl,
    const uintptr_t* __restrict__ dst_k_layer_tbl,
    const uintptr_t* __restrict__ src_v_layer_tbl,
    const uintptr_t* __restrict__ dst_v_layer_tbl) {
  const int32_t global_thread = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t lane = global_thread % WARP_SIZE;
  const int32_t warp = global_thread / WARP_SIZE;
  const int64_t end_layer = start_layer_id + num_layers_to_process;
  const int64_t first_item = warp * items_per_warp;
  const int64_t past_last_item = first_item + items_per_warp;

  for (int64_t item = first_item; item < past_last_item && item < num_items; ++item) {
    const int64_t src_slot = src_indices[item];
    const int64_t dst_slot = dst_indices[item];

    for (int64_t layer = start_layer_id; layer < end_layer; ++layer) {
      const char* k_from = SrcOffsetFn(
          static_cast<const char*>(src_k), src_k_layer_tbl, layer, src_layout_dim, src_slot, item_size_bytes);
      char* k_to =
          DstOffsetFn(static_cast<char*>(dst_k), dst_k_layer_tbl, layer, dst_layout_dim, dst_slot, item_size_bytes);
      transfer_item_warp(lane, k_from, k_to, item_size_bytes);

      // MLA caches have no separate V buffer.
      if constexpr (!IsMLA) {
        const char* v_from = SrcOffsetFn(
            static_cast<const char*>(src_v), src_v_layer_tbl, layer, src_layout_dim, src_slot, item_size_bytes);
        char* v_to = DstOffsetFn(
            static_cast<char*>(dst_v), dst_v_layer_tbl, layer, dst_layout_dim, dst_slot, item_size_bytes);
        transfer_item_warp(lane, v_from, v_to, item_size_bytes);
      }
    }
  }
}

255
// Validates arguments and launches the KV transfer kernel on the current CUDA stream.
//
// The work split: num_items = src_indices.numel() items are spread over at most
// block_quota * num_warps_per_block logical warps. Each warp handles items_per_warp consecutive items, and
// blocks have num_warps_per_block * WARP_SIZE threads.
//
// PageHeadLayout selects transfer_page_head_kernel_impl, which copies item_size / head_num bytes per head and
// uses page_size / head_num. Otherwise transfer_kernel_impl copies whole items, and IsMLA skips the V buffers.
//
// Throws (TORCH_CHECK) on:
//  - non-CUDA or non-int64 indices, or mismatched index lengths;
//  - item_size not a multiple of 8, or non-positive block_quota / num_warps_per_block;
//  - for PageHeadLayout: non-positive page_size / head_num, or a per-head size that is not a multiple of 8.
// An empty index list is a no-op.
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA, bool PageHeadLayout = false>
void transfer_kv_launcher(
    const at::Tensor& src_k,
    at::Tensor& dst_k,
    const at::Tensor& src_v,
    at::Tensor& dst_v,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t dst_layout_dim,
    const at::Tensor& src_k_layers,
    const at::Tensor& dst_k_layers,
    const at::Tensor& src_v_layers,
    const at::Tensor& dst_v_layers,
    int64_t block_quota,
    int64_t num_warps_per_block,
    const int64_t page_size = 16,
    const int64_t head_num = 1) {
  TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
  TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
  TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
  TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long");
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(item_size % 8 == 0, "Item byte size must be divisible by 8");
  // Both values are divisors in the work-split computation below.
  TORCH_CHECK(block_quota > 0, "Block quota must be positive");
  TORCH_CHECK(num_warps_per_block > 0, "Number of warps per block must be positive");
  if constexpr (PageHeadLayout) {
    // The page-head kernel divides by page_size and head_num on the device, and copies each head in 8-byte
    // words. A per-head size that is not a multiple of 8 would be truncated or would produce misaligned accesses.
    TORCH_CHECK(page_size > 0, "Page size must be positive");
    TORCH_CHECK(head_num > 0, "Head number must be positive");
    TORCH_CHECK(item_size % (head_num * 8) == 0, "Item byte size must be divisible by 8 * head_num");
  }

  const int64_t num_items = src_indices.numel();
  if (num_items == 0) {
    // Nothing to copy. Returning early also avoids div_up(num_items, 0) below (items_per_warp would be 0)
    // and a zero-block launch.
    return;
  }

  auto div_up = [](int64_t x, int64_t y) { return (x + y - 1) / y; };
  const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
  const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
  dim3 grid_dim(num_blocks, 1, 1);
  const int32_t threads_per_block = num_warps_per_block * WARP_SIZE;

  // Undefined tensors (and all V tensors for MLA) are passed to the kernel as null pointers.
  const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr;
  void* dst_k_ptr = dst_k.defined() ? dst_k.data_ptr() : nullptr;
  const void* src_v_ptr = IsMLA || !src_v.defined() ? nullptr : src_v.data_ptr();
  void* dst_v_ptr = IsMLA || !dst_v.defined() ? nullptr : dst_v.data_ptr();
  const uintptr_t* src_k_tbl_ptr = src_k_layers.defined() ? src_k_layers.data_ptr<uintptr_t>() : nullptr;
  const uintptr_t* dst_k_tbl_ptr = dst_k_layers.defined() ? dst_k_layers.data_ptr<uintptr_t>() : nullptr;
  const uintptr_t* src_v_tbl_ptr = IsMLA || !src_v_layers.defined() ? nullptr : src_v_layers.data_ptr<uintptr_t>();
  const uintptr_t* dst_v_tbl_ptr = IsMLA || !dst_v_layers.defined() ? nullptr : dst_v_layers.data_ptr<uintptr_t>();

  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  if constexpr (PageHeadLayout) {
    transfer_page_head_kernel_impl<SrcOffsetFn, DstOffsetFn><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
        src_k_ptr,
        dst_k_ptr,
        src_v_ptr,
        dst_v_ptr,
        src_indices.data_ptr<int64_t>(),
        dst_indices.data_ptr<int64_t>(),
        start_layer_id,
        num_layers_to_process,
        num_items,
        items_per_warp,
        item_size,
        src_layout_dim,
        dst_layout_dim,
        src_k_tbl_ptr,
        dst_k_tbl_ptr,
        src_v_tbl_ptr,
        dst_v_tbl_ptr,
        page_size,
        head_num);
  } else {
    transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
        src_k_ptr,
        dst_k_ptr,
        src_v_ptr,
        dst_v_ptr,
        src_indices.data_ptr<int64_t>(),
        dst_indices.data_ptr<int64_t>(),
        start_layer_id,
        num_layers_to_process,
        num_items,
        items_per_warp,
        item_size,
        src_layout_dim,
        dst_layout_dim,
        src_k_tbl_ptr,
        dst_k_tbl_ptr,
        src_v_tbl_ptr,
        dst_v_tbl_ptr);
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Copies K and V items of one layer between two layer-first buffers. Both sides are addressed directly from
// their base pointers with layer 0 and a zero layer stride, so each buffer is effectively a single layer and no
// layer base tables are used.
void transfer_kv_per_layer(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor no_table;
  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, /*IsMLA=*/false>(
      src_k,
      dst_k,
      src_v,
      dst_v,
      src_indices,
      dst_indices,
      /*start_layer_id=*/0,
      /*num_layers_to_process=*/1,
      item_size,
      /*src_layout_dim=*/0,
      /*dst_layout_dim=*/0,
      /*src_k_layers=*/no_table,
      /*dst_k_layers=*/no_table,
      /*src_v_layers=*/no_table,
      /*dst_v_layers=*/no_table,
      block_quota,
      num_warps_per_block);
}

375
// Copies layer `layer_id` of K and V from page-first source buffers (slot stride `src_layout_dim`) into
// single-layer layer-first destination buffers. The destination uses a zero layer stride, so the layer index
// only affects the source address.
void transfer_kv_per_layer_pf_lf(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t layer_id,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor no_table;
  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, /*IsMLA=*/false>(
      src_k,
      dst_k,
      src_v,
      dst_v,
      src_indices,
      dst_indices,
      /*start_layer_id=*/layer_id,
      /*num_layers_to_process=*/1,
      item_size,
      src_layout_dim,
      /*dst_layout_dim=*/0,
      /*src_k_layers=*/no_table,
      /*dst_k_layers=*/no_table,
      /*src_v_layers=*/no_table,
      /*dst_v_layers=*/no_table,
      block_quota,
      num_warps_per_block);
}

408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
// Copies layer `layer_id` of K and V from page-head source buffers into single-layer layer-first destination
// buffers, one head at a time. `src_layout_dim` is passed to get_global_offset_ph as its page_dim argument;
// `page_size` and `head_num` describe the page-head layout.
void transfer_kv_per_layer_ph_lf(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t layer_id,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t page_size,
    int64_t head_num,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor no_table;
  transfer_kv_launcher<
      get_global_offset_ph<const char>,
      get_global_offset_per_head_lf<char>,
      /*IsMLA=*/false,
      /*PageHeadLayout=*/true>(
      src_k,
      dst_k,
      src_v,
      dst_v,
      src_indices,
      dst_indices,
      /*start_layer_id=*/layer_id,
      /*num_layers_to_process=*/1,
      item_size,
      src_layout_dim,
      /*dst_layout_dim=*/0,
      /*src_k_layers=*/no_table,
      /*dst_k_layers=*/no_table,
      /*src_v_layers=*/no_table,
      /*dst_v_layers=*/no_table,
      block_quota,
      num_warps_per_block,
      page_size,
      head_num);
}

445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
// Copies K and V items for all `num_layers` layers. Both sides are layer-first buffers whose per-layer base
// addresses are given by the *_layers tables (one entry per layer), so no single base pointers are passed.
void transfer_kv_all_layer(
    const at::Tensor src_k_layers,
    const at::Tensor dst_k_layers,
    const at::Tensor src_v_layers,
    const at::Tensor dst_v_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
  at::Tensor no_buffer;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, /*IsMLA=*/false>(
      /*src_k=*/no_buffer,
      /*dst_k=*/no_buffer,
      /*src_v=*/no_buffer,
      /*dst_v=*/no_buffer,
      src_indices,
      dst_indices,
      /*start_layer_id=*/0,
      /*num_layers_to_process=*/num_layers,
      item_size,
      /*src_layout_dim=*/0,
      /*dst_layout_dim=*/0,
      src_k_layers,
      dst_k_layers,
      src_v_layers,
      dst_v_layers,
      block_quota,
      num_warps_per_block);
}

// Copies K and V items for all `num_layers` layers from layer-first sources (per-layer base tables) into
// page-first destination buffers with slot stride `dst_layout_dim`.
void transfer_kv_all_layer_lf_pf(
    const at::Tensor src_k_layers,
    at::Tensor dst_k,
    const at::Tensor src_v_layers,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t dst_layout_dim,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
  at::Tensor none;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, /*IsMLA=*/false>(
      /*src_k=*/none,
      dst_k,
      /*src_v=*/none,
      dst_v,
      src_indices,
      dst_indices,
      /*start_layer_id=*/0,
      /*num_layers_to_process=*/num_layers,
      item_size,
      /*src_layout_dim=*/0,
      dst_layout_dim,
      /*src_k_layers=*/src_k_layers,
      /*dst_k_layers=*/none,
      /*src_v_layers=*/src_v_layers,
      /*dst_v_layers=*/none,
      block_quota,
      num_warps_per_block);
}

512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
// Copies K and V items for all `num_layers` layers, head by head, from layer-first sources (per-layer base
// tables) into page-head destination buffers. `dst_layout_dim` is passed to get_global_offset_ph as its
// page_dim argument; `page_size` and `head_num` describe the page-head layout.
void transfer_kv_all_layer_lf_ph(
    const at::Tensor src_k_layers,
    at::Tensor dst_k,
    const at::Tensor src_v_layers,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t dst_layout_dim,
    int64_t num_layers,
    int64_t page_size,
    int64_t head_num,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
  at::Tensor none;
  transfer_kv_launcher<
      get_global_offset_per_head_lf_tbl<const char>,
      get_global_offset_ph<char>,
      /*IsMLA=*/false,
      /*PageHeadLayout=*/true>(
      /*src_k=*/none,
      dst_k,
      /*src_v=*/none,
      dst_v,
      src_indices,
      dst_indices,
      /*start_layer_id=*/0,
      /*num_layers_to_process=*/num_layers,
      item_size,
      /*src_layout_dim=*/0,
      dst_layout_dim,
      /*src_k_layers=*/src_k_layers,
      /*dst_k_layers=*/none,
      /*src_v_layers=*/src_v_layers,
      /*dst_v_layers=*/none,
      block_quota,
      num_warps_per_block,
      page_size,
      head_num);
}

550
551
552
553
554
555
556
557
// MLA variant of transfer_kv_per_layer: copies items of a single layer between two layer-first buffers. There
// is no separate V buffer, so the V arguments are left undefined and the launcher runs with IsMLA = true.
void transfer_kv_per_layer_mla(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor none;
  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, /*IsMLA=*/true>(
      /*src_k=*/src,
      /*dst_k=*/dst,
      /*src_v=*/none,
      /*dst_v=*/none,
      src_indices,
      dst_indices,
      /*start_layer_id=*/0,
      /*num_layers_to_process=*/1,
      item_size,
      /*src_layout_dim=*/0,
      /*dst_layout_dim=*/0,
      /*src_k_layers=*/none,
      /*dst_k_layers=*/none,
      /*src_v_layers=*/none,
      /*dst_v_layers=*/none,
      block_quota,
      num_warps_per_block);
}

579
// MLA variant of transfer_kv_per_layer_pf_lf: copies layer `layer_id` from a page-first source buffer (slot
// stride `src_layout_dim`) into a single-layer layer-first destination. No V buffers are involved.
void transfer_kv_per_layer_mla_pf_lf(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t layer_id,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor none;
  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, /*IsMLA=*/true>(
      /*src_k=*/src,
      /*dst_k=*/dst,
      /*src_v=*/none,
      /*dst_v=*/none,
      src_indices,
      dst_indices,
      /*start_layer_id=*/layer_id,
      /*num_layers_to_process=*/1,
      item_size,
      src_layout_dim,
      /*dst_layout_dim=*/0,
      /*src_k_layers=*/none,
      /*dst_k_layers=*/none,
      /*src_v_layers=*/none,
      /*dst_v_layers=*/none,
      block_quota,
      num_warps_per_block);
}

// MLA variant of transfer_kv_all_layer: copies items for all `num_layers` layers between layer-first buffers
// whose per-layer base addresses are given by `src_layers` / `dst_layers`. No V buffers or V tables are used.
void transfer_kv_all_layer_mla(
    const at::Tensor src_layers,
    const at::Tensor dst_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
  at::Tensor none;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, /*IsMLA=*/true>(
      /*src_k=*/none,
      /*dst_k=*/none,
      /*src_v=*/none,
      /*dst_v=*/none,
      src_indices,
      dst_indices,
      /*start_layer_id=*/0,
      /*num_layers_to_process=*/num_layers,
      item_size,
      /*src_layout_dim=*/0,
      /*dst_layout_dim=*/0,
      /*src_k_layers=*/src_layers,
      /*dst_k_layers=*/dst_layers,
      /*src_v_layers=*/none,
      /*dst_v_layers=*/none,
      block_quota,
      num_warps_per_block);
}

// MLA variant of transfer_kv_all_layer_lf_pf: copies items for all `num_layers` layers from layer-first
// sources (per-layer base table `src_layers`) into a page-first destination with slot stride `dst_layout_dim`.
void transfer_kv_all_layer_mla_lf_pf(
    const at::Tensor src_layers,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t dst_layout_dim,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
  at::Tensor none;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, /*IsMLA=*/true>(
      /*src_k=*/none,
      /*dst_k=*/dst,
      /*src_v=*/none,
      /*dst_v=*/none,
      src_indices,
      dst_indices,
      /*start_layer_id=*/0,
      /*num_layers_to_process=*/num_layers,
      item_size,
      /*src_layout_dim=*/0,
      dst_layout_dim,
      /*src_k_layers=*/src_layers,
      /*dst_k_layers=*/none,
      /*src_v_layers=*/none,
      /*dst_v_layers=*/none,
      block_quota,
      num_warps_per_block);
}

// Copies `page_size` consecutive rows along dim 0 from src_buffer (starting at row `src_page_index`) into
// dst_buffer (starting at row `dst_page_index`). This is a non-blocking Tensor::copy_, so the copy is enqueued
// rather than waited on.
inline void transfer_page_direct(
    const at::Tensor src_buffer,
    at::Tensor dst_buffer,
    int64_t src_page_index,
    int64_t dst_page_index,
    int64_t page_size) {
  const at::Tensor src_rows = src_buffer.slice(0, src_page_index, src_page_index + page_size);
  at::Tensor dst_rows = dst_buffer.slice(0, dst_page_index, dst_page_index + page_size);
  constexpr bool kNonBlocking = true;
  dst_rows.copy_(src_rows, kNonBlocking);
}

685
686
687
688
689
690
691
692
// Copies rows between matching layer tensors (src_layers[j] -> dst_layers[j]) using Tensor::copy_ instead of
// a custom kernel.
//
// The index tensors are brought to the host (.cpu()) and read as int64. Consecutive positions whose source AND
// destination indices both increase by exactly 1 are merged into one run. Each run is then copied with a single
// slice copy per layer.
//
// Throws (TORCH_CHECK) on mismatched layer counts or index lengths, a non-positive page_size, or an index count
// that is not a multiple of page_size.
void transfer_kv_direct(
    const std::vector<at::Tensor>& src_layers,
    std::vector<at::Tensor> dst_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t page_size) {
  TORCH_CHECK(
      src_layers.size() == dst_layers.size(), "Source and destination layers must have the same number of layers");
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(page_size > 0, "Page size must be positive");
  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");

  const at::Tensor src_host = src_indices.cpu();
  const at::Tensor dst_host = dst_indices.cpu();
  const int64_t total = src_host.numel();
  const int64_t layer_count = src_layers.size();
  const int64_t* src_idx = src_host.data_ptr<int64_t>();
  const int64_t* dst_idx = dst_host.data_ptr<int64_t>();

  int64_t run_begin = 0;
  while (run_begin < total) {
    // Extend the run while both index sequences stay contiguous.
    int64_t run_end = run_begin + 1;
    while (run_end < total && src_idx[run_end] - src_idx[run_end - 1] == 1 &&
           dst_idx[run_end] - dst_idx[run_end - 1] == 1) {
      ++run_end;
    }

    const int64_t run_length = run_end - run_begin;
    for (int64_t layer = 0; layer < layer_count; ++layer) {
      transfer_page_direct(src_layers[layer], dst_layers[layer], src_idx[run_begin], dst_idx[run_begin], run_length);
    }
    run_begin = run_end;
  }
}
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807

// Page-by-page copy between one "page-first" side and one "layer-first" side using Tensor::copy_.
//
// Direction:
//  - IsLf2Pf = true: layer-first sources -> page-first destinations.
//  - IsLf2Pf = false: the reverse.
//
// Page-first side: one tensor (MLA) or two tensors (K, V). A page is reached with .select(0, page).select(0, layer).
// Layer-first side: one tensor per layer for K, followed in the non-MLA case by one tensor per layer for V.
//
// For every page, only the first index of that page is read:
//  - on the page-first side it is divided by page_size to get the page number;
//  - on the layer-first side it is the starting row.
// page_size rows are then copied for each layer start_layer_id + j.
//
// Throws (TORCH_CHECK) on mismatched index lengths, a non-positive page_size, or an index count that is not a
// multiple of page_size.
template <bool IsLf2Pf>
inline void transfer_kv_page_first_direct_impl(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t start_layer_id,
    int64_t page_size) {
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(page_size > 0, "Page size must be positive");
  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");

  const auto src_host = src_indices.cpu();
  const auto dst_host = dst_indices.cpu();
  const int64_t page_count = src_host.size(0) / page_size;

  const std::vector<at::Tensor>& page_first = IsLf2Pf ? dst_ptrs : src_ptrs;
  const std::vector<at::Tensor>& layer_first = IsLf2Pf ? src_ptrs : dst_ptrs;
  const bool is_mla = page_first.size() == 1;
  const int64_t layer_count = is_mla ? layer_first.size() : layer_first.size() / 2;
  const int64_t kv_count = is_mla ? 1 : 2;  // K only for MLA, otherwise K then V

  for (const auto page : c10::irange(page_count)) {
    const int64_t src_first = src_host[page * page_size].item<int64_t>();
    const int64_t dst_first = dst_host[page * page_size].item<int64_t>();
    // The page-first side is addressed by page number, the layer-first side by row (token) index.
    const int64_t page_no = IsLf2Pf ? dst_first / page_size : src_first / page_size;
    const int64_t row_no = IsLf2Pf ? src_first : dst_first;

    for (int64_t layer = 0; layer < layer_count; ++layer) {
      for (int64_t kv = 0; kv < kv_count; ++kv) {
        const at::Tensor page_slot = page_first[kv].select(0, page_no).select(0, start_layer_id + layer);
        const at::Tensor& layer_buffer = layer_first[layer + kv * layer_count];
        if constexpr (IsLf2Pf) {
          transfer_page_direct(layer_buffer, page_slot, row_no, 0, page_size);
        } else {
          transfer_page_direct(page_slot, layer_buffer, 0, row_no, page_size);
        }
      }
    }
  }
}

// Direct (Tensor::copy_-based) transfer of layer `layer_id` from page-first source tensors into the given
// layer-first destination tensors.
void transfer_kv_per_layer_direct_pf_lf(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t layer_id,
    int64_t page_size) {
  constexpr bool kLayerFirstToPageFirst = false;
  transfer_kv_page_first_direct_impl<kLayerFirstToPageFirst>(
      src_ptrs, dst_ptrs, src_indices, dst_indices, /*start_layer_id=*/layer_id, page_size);
}

// Direct (Tensor::copy_-based) transfer of all given layers, starting at layer 0, from layer-first source
// tensors into page-first destination tensors.
void transfer_kv_all_layer_direct_lf_pf(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t page_size) {
  constexpr bool kLayerFirstToPageFirst = true;
  constexpr int64_t kFirstLayer = 0;
  transfer_kv_page_first_direct_impl<kLayerFirstToPageFirst>(
      src_ptrs, dst_ptrs, src_indices, dst_indices, kFirstLayer, page_size);
}
liucong8560's avatar
liucong8560 committed
808
809
810
811
812
813
814
815
816
817
818
819
820
821

// Returns (a + b - 1) / b with C++ truncating division: the ceiling of a / b for a >= 0 and b > 0.
__device__ int64_t ceil_div(int64_t a, int64_t b) {
  const int64_t biased = a + b - 1;
  return biased / b;
}

// Smaller of two int64 values.
__device__ int64_t safe_min(int64_t a, int64_t b) {
  if (b <= a) {
    return b;
  }
  return a;
}

// Decode-step slot allocation: one thread per request, threads with index >= bs exit immediately.
//
// Request r grows from seq_lens[r] - 1 to seq_lens[r] tokens. It opens
// ceil(seq_len / page_size) - ceil((seq_len - 1) / page_size) new pages, which is 0 or 1 for seq_len >= 1.
//  - No new page: its new slot is last_loc[r] + 1.
//  - New page: it takes free_page_ptr[k], where k is the number of pages opened by requests 0..r-1. Its slot is
//    that page's first token, page * page_size.
//
// Each thread recomputes the prefix sum itself, so the work is O(bs) per thread.
// last_loc is read as int32, and `prev + 1` is evaluated in 32-bit arithmetic before being stored as int64.
__global__ void launch_alloc_decode_kernel(
    const int64_t* seq_lens_ptr,
    const int32_t* last_loc_ptr,
    const int64_t* free_page_ptr,
    int64_t* out_indices,
    int64_t bs,
    int64_t page_size) {
  const int64_t row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= bs) {
    return;
  }

  // Inclusive prefix sum of pages opened by requests 0..row.
  int64_t pages_through_row = 0;
  for (int64_t r = 0; r <= row; ++r) {
    const int64_t len = seq_lens_ptr[r];
    pages_through_row += ceil_div(len, page_size) - ceil_div(len - 1, page_size);
  }

  const int64_t my_len = seq_lens_ptr[row];
  const int64_t my_new_pages = ceil_div(my_len, page_size) - ceil_div(my_len - 1, page_size);

  if (my_new_pages != 0) {
    const int64_t page = free_page_ptr[pages_through_row - my_new_pages];
    out_indices[row] = page * page_size;
  } else {
    const int32_t prev = last_loc_ptr[row];
    out_indices[row] = prev + 1;
  }
}

// Extend-step slot allocation: one thread per request, threads with index >= bs exit immediately.
//
// Request r grows from pre_lens[r] to seq_lens[r] tokens. Its output slots start at the exclusive prefix sum of
// (seq_len - pre_len) over earlier requests. The new pages it needs start at the exclusive prefix sum of
// ceil(seq_len / page_size) - ceil(pre_len / page_size) over earlier requests.
//
// Slots are written in three parts:
//   1. the rest of the currently open page, continuing from last_loc[r] + 1;
//   2. whole new pages taken from free_page_ptr;
//   3. the leading part of one final, partially filled new page.
//
// Each thread recomputes both prefix sums itself, in a single O(bs) loop.
__global__ void launch_alloc_extend_kernel(
    const int64_t* pre_lens_ptr,
    const int64_t* seq_lens_ptr,
    const int64_t* last_loc_ptr,
    const int64_t* free_page_ptr,
    int64_t* out_indices,
    int64_t bs,
    int64_t page_size) {
  const int64_t row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= bs) {
    return;
  }

  // Inclusive prefix sums over requests 0..row of new tokens and of newly opened pages.
  int64_t tokens_through_row = 0;
  int64_t pages_through_row = 0;
  for (int64_t r = 0; r <= row; ++r) {
    const int64_t s = seq_lens_ptr[r];
    const int64_t p = pre_lens_ptr[r];
    tokens_through_row += s - p;
    pages_through_row += ceil_div(s, page_size) - ceil_div(p, page_size);
  }

  const int64_t seq_len = seq_lens_ptr[row];
  const int64_t pre_len = pre_lens_ptr[row];
  const int64_t out_base = tokens_through_row - (seq_len - pre_len);
  const int64_t my_new_pages = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
  const int64_t first_new_page = pages_through_row - my_new_pages;
  const int64_t open_page_end = ceil_div(pre_len, page_size) * page_size;

  // Part 1: fill the remainder of the already-open page.
  const int64_t last_loc = last_loc_ptr[row];
  const int64_t num_part1 = safe_min(seq_len, open_page_end) - pre_len;
  for (int64_t k = 0; k < num_part1 && k < page_size; ++k) {
    out_indices[out_base + k] = last_loc + 1 + k;
  }
  if (pre_len + num_part1 == seq_len) {
    return;
  }

  // Part 2: whole new pages.
  const int64_t num_part2 = (seq_len / page_size) * page_size - open_page_end;
  for (int64_t k = 0; k < num_part2; ++k) {
    const int64_t page = free_page_ptr[first_new_page + k / page_size];
    out_indices[out_base + num_part1 + k] = page * page_size + k % page_size;
  }
  if (pre_len + num_part1 + num_part2 == seq_len) {
    return;
  }

  // Part 3: the leading slots of the last, partially filled new page.
  const int64_t num_part3 = seq_len - (seq_len / page_size) * page_size;
  const int64_t tail_page = free_page_ptr[first_new_page + my_new_pages - 1];
  for (int64_t k = 0; k < num_part3 && k < page_size; ++k) {
    out_indices[out_base + num_part1 + num_part2 + k] = tail_page * page_size + k;
  }
}
linhai1's avatar
linhai1 committed
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
// One thread per request (threads with index >= bs exit). accept_lens is int32 here.
//
// Request pid writes the positions of its accept_length accepted tokens:
//   seq_len - accept_length, ..., seq_len - 1.
// They go at positions_ptr[cumsum .. cumsum + accept_length), where cumsum is the exclusive prefix sum of
// accept_lens over earlier requests. The id at verified_id_ptr[cumsum + accept_length - 1] (its last accepted
// token) is stored into new_verified_id_ptr[pid].
//
// NOTE(review): assumes positions_ptr has room for sum(accept_lens) entries and accept_length >= 1 — confirm
// with the caller.
__global__ void launch_create_extend_after_decode_spec_info_int32_kernel(
    const int32_t* verified_id_ptr,
    const int64_t* seq_lens_ptr,
    const int32_t* accept_lens_ptr,
    int64_t* positions_ptr,
    int32_t* new_verified_id_ptr,
    int64_t bs) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t seq_length = seq_lens_ptr[pid];
  int32_t accept_length = accept_lens_ptr[pid];

  // Exclusive prefix sum of the accept lengths of earlier requests.
  int32_t accept_len_cumsum = 0;
  for (int32_t offset = 0; offset < pid; offset++) {
    accept_len_cumsum += accept_lens_ptr[offset];
  }

  // Bounded by accept_length only. The number of accepted tokens of one request is unrelated to the batch size;
  // bounding by bs as well left positions unwritten whenever accept_length > bs.
  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
  for (int32_t offset = 0; offset < accept_length; offset++) {
    positions_ptr1[offset] = seq_length - accept_length + offset;
  }

  int32_t verified_idx = accept_len_cumsum + accept_length - 1;
  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}

// Same as the int32 variant, but accept_lens is int64. One thread per request (threads with index >= bs exit).
//
// Request pid writes the positions of its accept_length accepted tokens:
//   seq_len - accept_length, ..., seq_len - 1.
// They go at positions_ptr[cumsum .. cumsum + accept_length), where cumsum is the exclusive prefix sum of
// accept_lens over earlier requests. The id at verified_id_ptr[cumsum + accept_length - 1] is stored into
// new_verified_id_ptr[pid].
//
// NOTE(review): assumes positions_ptr has room for sum(accept_lens) entries and accept_length >= 1 — confirm
// with the caller.
__global__ void launch_create_extend_after_decode_spec_info_int64_kernel(
    const int32_t* verified_id_ptr,
    const int64_t* seq_lens_ptr,
    const int64_t* accept_lens_ptr,
    int64_t* positions_ptr,
    int32_t* new_verified_id_ptr,
    int64_t bs) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs) return;

  int64_t seq_length = seq_lens_ptr[pid];
  int64_t accept_length = accept_lens_ptr[pid];

  // Exclusive prefix sum of the accept lengths of earlier requests.
  int64_t accept_len_cumsum = 0;
  for (int64_t offset = 0; offset < pid; offset++) {
    accept_len_cumsum += accept_lens_ptr[offset];
  }

  // Bounded by accept_length only. The number of accepted tokens of one request is unrelated to the batch size;
  // bounding by bs as well left positions unwritten whenever accept_length > bs.
  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
  for (int64_t offset = 0; offset < accept_length; offset++) {
    positions_ptr1[offset] = seq_length - accept_length + offset;
  }

  int64_t verified_idx = accept_len_cumsum + accept_length - 1;
  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}
liucong8560's avatar
liucong8560 committed
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005

// Host launcher for launch_alloc_decode_kernel: one thread per request,
// 64 threads per block, enqueued on PyTorch's current stream.
//
// Tensor storage is reinterpreted with an unchecked static_cast, so callers
// must pass: seq_lens int64, last_loc int32, free_page int64, out_indices int64.
// NOTE(review): dcu_alloc_extend_kernel reads last_loc as int64 while this
// decode path reads it as int32 — confirm this matches the signature of
// launch_alloc_decode_kernel and what callers actually pass.
void dcu_alloc_decode_kernel(
  const at::Tensor seq_lens_ptr,
  const at::Tensor last_loc_ptr,
  const at::Tensor free_page_ptr,
  at::Tensor out_indices,
  int64_t bs,
  int64_t page_size) {

    // Empty batch: nothing to allocate, and a zero-sized grid would be
    // rejected as an invalid launch configuration.
    if (bs == 0) return;

    const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
    const int32_t* last_loc_ptr1 = static_cast<const int32_t*>(last_loc_ptr.data_ptr());
    const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
    int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());

    const int64_t block_size = 64;
    const int64_t grid_size = (bs + block_size - 1) / block_size;
    cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();

    launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
        seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

linhai1's avatar
linhai1 committed
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
// Host launcher for the create_extend_after_decode_spec_info kernels: one
// thread per request (64-thread blocks) on PyTorch's current stream.
// Dispatches on the dtype of accept_lens, which must be int32 or int64.
//
// The other tensors are reinterpreted with an unchecked static_cast and must
// be: verified_id int32, seq_lens int64, positions int64, new_verified_id int32.
void dcu_create_extend_after_decode_spec_info(
    const at::Tensor verified_id,
    const at::Tensor seq_lens,
    const at::Tensor accept_lens,
    at::Tensor positions,
    at::Tensor new_verified_id,
    int64_t bs) {

    const at::ScalarType accept_dtype = accept_lens.scalar_type();
    // Reject other dtypes instead of reinterpreting their bytes as int64.
    TORCH_CHECK(accept_dtype == at::kInt || accept_dtype == at::kLong,
                "accept_lens must be int32 or int64");

    // Empty batch: nothing to do, and a zero-sized grid is an invalid launch.
    if (bs == 0) return;

    const int32_t* verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
    const int64_t* seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
    int64_t* positions_ptr = static_cast<int64_t*>(positions.data_ptr());
    int32_t* new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());

    const int64_t block_size = 64;
    const int64_t grid_size = (bs + block_size - 1) / block_size;
    cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();

    if (accept_dtype == at::kInt) {
        const int32_t* accept_lens_ptr_int32 = static_cast<const int32_t*>(accept_lens.data_ptr());
        launch_create_extend_after_decode_spec_info_int32_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
            verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int32, positions_ptr, new_verified_id_ptr, bs);
    } else {
        const int64_t* accept_lens_ptr_int64 = static_cast<const int64_t*>(accept_lens.data_ptr());
        launch_create_extend_after_decode_spec_info_int64_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
            verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int64, positions_ptr, new_verified_id_ptr, bs);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

liucong8560's avatar
liucong8560 committed
1053
1054
1055
1056
1057
1058
1059
// Host launcher for launch_alloc_extend_kernel: one thread per request,
// 64 threads per block, enqueued on PyTorch's current stream.
//
// Tensor storage is reinterpreted with an unchecked static_cast, so every
// tensor (pre_lens, seq_lens, last_loc, free_page, out_indices) must be int64.
void dcu_alloc_extend_kernel(
    const at::Tensor pre_lens_ptr,
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t page_size) {

    // Empty batch: nothing to allocate, and a zero-sized grid would be
    // rejected as an invalid launch configuration.
    if (bs == 0) return;

    const int64_t* pre_lens_ptr1 = static_cast<const int64_t*>(pre_lens_ptr.data_ptr());
    const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
    const int64_t* last_loc_ptr1 = static_cast<const int64_t*>(last_loc_ptr.data_ptr());
    const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
    int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());

    const int64_t block_size = 64;
    const int64_t grid_size = (bs + block_size - 1) / block_size;
    cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();

    launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
        pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Scatters newly allocated KV-cache slots into the request->token table.
//
// One thread per request (index `req` < bs). Let
//   n(r) = new_allocate_lens[r] - allocate_lens_ptr[r].
// Request `req` reads n(req) slots from out_cache_loc_ptr, starting after the
// n(r) slots of every earlier request r, narrows them to int32 and stores them
// in row req_pool_indices_ptr[req] of req_to_token_ptr (row length `shape`),
// columns [allocate_lens_ptr[req], new_allocate_lens[req]).
__global__ void launch_assign_req_to_token_pool(
    const int64_t* req_pool_indices_ptr,
    int32_t* req_to_token_ptr,
    const int64_t* allocate_lens_ptr,
    int64_t* new_allocate_lens,
    int64_t* out_cache_loc_ptr,
    int64_t shape,
    int64_t bs)
{
    const int64_t req = blockIdx.x * blockDim.x + threadIdx.x;
    if (req >= bs) {
        return;
    }

    // Source offset: total slot count of all preceding requests.
    int64_t src_base = 0;
    for (int prev = 0; prev < req; ++prev) {
        src_base += new_allocate_lens[prev] - allocate_lens_ptr[prev];
    }

    const int64_t first_col = allocate_lens_ptr[req];
    const int64_t n_slots = new_allocate_lens[req] - first_col;
    int32_t* dst_row = req_to_token_ptr + req_pool_indices_ptr[req] * shape + first_col;
    const int64_t* src = out_cache_loc_ptr + src_base;

    #pragma unroll(32)
    for (int k = 0; k < n_slots; ++k) {
        dst_row[k] = src[k];
    }
}


// Host launcher for launch_assign_req_to_token_pool: one thread per request,
// 64 threads per block, enqueued on PyTorch's current stream.
//
// Tensor storage is reinterpreted with an unchecked static_cast: req_to_token
// must be int32; req_pool_indices, allocate_lens, new_allocate_lens and
// out_cache_loc must be int64. `shape` is the row length of req_to_token.
void dcu_assign_req_to_token_pool(
    const at::Tensor req_pool_indices_ptr,
    at::Tensor req_to_token_ptr,
    const at::Tensor allocate_lens_ptr,
    at::Tensor new_allocate_lens,
    at::Tensor out_cache_loc_ptr,
    int64_t shape,
    int64_t bs) {

    // Empty batch: nothing to assign, and a zero-sized grid would be
    // rejected as an invalid launch configuration.
    if (bs == 0) return;

    const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
    int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
    const int64_t* allocate_lens_ptr1 = static_cast<const int64_t*>(allocate_lens_ptr.data_ptr());
    int64_t* new_allocate_lens1 = static_cast<int64_t*>(new_allocate_lens.data_ptr());
    int64_t* out_cache_loc_ptr1 = static_cast<int64_t*>(out_cache_loc_ptr.data_ptr());

    const int64_t block_size = 64;
    const int64_t grid_size = (bs + block_size - 1) / block_size;
    cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();

    launch_assign_req_to_token_pool<<<grid_size, block_size, 0, torch_current_stream>>>(
        req_pool_indices_ptr1, req_to_token_ptr1, allocate_lens_ptr1, new_allocate_lens1,
        out_cache_loc_ptr1, shape, bs);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}


// For each entry i < num_tokens, stores the KV slot of the last prefix token:
//   result[i] = req_to_token[req_pool_indices_tensor[i] * req_to_token_stride
//                            + prefix_lens_tensor[i] - 1]   if prefix_lens_tensor[i] > 0
//   result[i] = -1                                          otherwise.
// One thread per entry; surplus threads do nothing.
__global__ void get_last_loc_kernel(
    const int32_t* __restrict__ req_to_token,
    const int64_t* __restrict__ req_pool_indices_tensor,
    const int64_t* __restrict__ prefix_lens_tensor,
    int64_t* __restrict__ result,
    int64_t num_tokens,
    int64_t req_to_token_stride){

    const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_tokens) {
        const int64_t prefix_len = prefix_lens_tensor[i];
        // The table is only touched for requests that actually have a prefix.
        result[i] = (prefix_len > 0)
            ? static_cast<int64_t>(
                  req_to_token[req_pool_indices_tensor[i] * req_to_token_stride + (prefix_len - 1)])
            : static_cast<int64_t>(-1);
    }
}

// Returns, for each of the N requests, the KV slot of its last prefix token:
//   result[i] = req_to_token[req_pool_indices[i], prefix_lens[i] - 1]  if prefix_lens[i] > 0
//             = -1                                                     otherwise
// req_to_token must be a 2-D int32 CUDA tensor. req_pool_indices and
// prefix_lens must be 1-D int64 CUDA tensors of equal length (dtypes are
// enforced by data_ptr<T>()). The int64 result has prefix_lens' shape and is
// filled asynchronously on PyTorch's current stream.
at::Tensor dcu_get_last_loc(
    const at::Tensor req_to_token,
    const at::Tensor req_pool_indices,
    const at::Tensor prefix_lens) {

    TORCH_CHECK(req_to_token.device().is_cuda(), "req_to_token must be CUDA tensor");
    TORCH_CHECK(req_pool_indices.device().is_cuda(), "req_pool_indices must be CUDA tensor");
    TORCH_CHECK(prefix_lens.device().is_cuda(), "prefix_lens must be CUDA tensor");

    TORCH_CHECK(req_to_token.dim() == 2, "req_to_token must be 2D tensor [batch, seq_len]");
    TORCH_CHECK(prefix_lens.dim() == 1, "prefix_lens must be 1D");
    TORCH_CHECK(req_pool_indices.dim() == 1, "req_pool_indices must be 1D");

    const int64_t num_tokens = prefix_lens.numel();
    TORCH_CHECK(req_pool_indices.numel() == num_tokens, "req_pool_indices must have same length as prefix_lens");

    // The kernel addresses element (r, c) as r * stride + c, so only the inner
    // dimension has to be dense. Reuse the (potentially large) table as-is when
    // it already is; otherwise compact it. Either way the row stride must come
    // from the tensor the kernel actually reads, not from the original view.
    at::Tensor req_to_token_c = req_to_token;
    if (req_to_token_c.stride(1) != 1) {
        req_to_token_c = req_to_token_c.contiguous();
    }
    const int64_t req_to_token_stride = req_to_token_c.stride(0);

    auto req_pool_indices_c = req_pool_indices.contiguous();
    auto prefix_lens_c = prefix_lens.contiguous();

    const int32_t* req_to_token_ptr = req_to_token_c.data_ptr<int32_t>();
    const int64_t* req_pool_indices_ptr = req_pool_indices_c.data_ptr<int64_t>();
    const int64_t* prefix_lens_ptr = prefix_lens_c.data_ptr<int64_t>();

    auto result = at::empty_like(prefix_lens_c);
    int64_t* result_ptr = result.data_ptr<int64_t>();

    // Empty batch: nothing to compute, and a zero-sized grid is an invalid launch.
    if (num_tokens == 0) return result;

    const int64_t block_size = 64;
    const int64_t grid_size = (num_tokens + block_size - 1) / block_size;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    get_last_loc_kernel<<<grid_size, block_size, 0, stream>>>(
        req_to_token_ptr,
        req_pool_indices_ptr,
        prefix_lens_ptr,
        result_ptr,
        num_tokens,
        req_to_token_stride
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    return result;
}


// Gathers each request's KV-cache slots out of the request->token table.
//
// One thread per request (index `req` < bs). Request `req` copies columns
// [start_offset[req], end_offset[req]) of row req_pool_indices[req] of
// req_to_token (row length pool_len), widened to int64, into a contiguous run
// of out_cache_loc that begins after the slots of every earlier request.
__global__ void launch_assign_extend_cache_locs_kernel(
    const int64_t* __restrict__ req_pool_indices,   // [bs]
    const int32_t* __restrict__ req_to_token,       // [max_num_req, pool_len]
    const int64_t* __restrict__ start_offset,       // [bs]
    const int64_t* __restrict__ end_offset,         // [bs]
    int64_t* __restrict__ out_cache_loc,            // [sum(draft_token_num)]
    int64_t pool_len,
    int64_t bs)
{
    const int req = blockIdx.x * blockDim.x + threadIdx.x;
    if (!(req < bs)) {
        return;
    }

    // Output offset: number of slots gathered by the preceding requests.
    int64_t written_before = 0;
    for (int prev = 0; prev < req; ++prev) {
        written_before += end_offset[prev] - start_offset[prev];
    }

    const int64_t first = start_offset[req];
    const int64_t count = end_offset[req] - first;
    const int32_t* row = req_to_token + req_pool_indices[req] * pool_len;
    int64_t* out = out_cache_loc + written_before;

    for (int64_t k = 0; k < count; ++k) {
        out[k] = row[first + k];
    }
}

// Host launcher for launch_assign_extend_cache_locs_kernel: one thread per
// request, 128 threads per block, enqueued on PyTorch's current stream.
// Dtypes are enforced by data_ptr<T>(): req_to_token int32; req_pool_indices,
// start_offset, end_offset and out_cache_loc int64. `pool_len` is the row
// length of req_to_token.
void dcu_assign_extend_cache_locs(
    const at::Tensor req_pool_indices,
    const at::Tensor req_to_token,
    const at::Tensor start_offset,
    const at::Tensor end_offset,
    at::Tensor out_cache_loc,
    int64_t pool_len,
    int64_t bs)
{
    const int64_t* req_pool_indices_ptr = req_pool_indices.data_ptr<int64_t>();
    const int32_t* req_to_token_ptr     = req_to_token.data_ptr<int32_t>();
    const int64_t* start_offset_ptr     = start_offset.data_ptr<int64_t>();
    const int64_t* end_offset_ptr       = end_offset.data_ptr<int64_t>();
    int64_t* out_cache_loc_ptr          = out_cache_loc.data_ptr<int64_t>();

    // Empty batch: nothing to gather, and a zero-sized grid would be rejected
    // as an invalid launch configuration.
    if (bs == 0) return;

    constexpr int64_t threads = 128;
    const int64_t blocks = (bs + threads - 1) / threads;
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    launch_assign_extend_cache_locs_kernel<<<blocks, threads, 0, stream>>>(
        req_pool_indices_ptr,
        req_to_token_ptr,
        start_offset_ptr,
        end_offset_ptr,
        out_cache_loc_ptr,
        pool_len,
        bs);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}


// Fills each request's FlashMLA page-table row:
//   kv_indices[pid][pg] =
//       req_to_token[req_pool_indices[pid]][kv_start + pg * PAGED_SIZE] / PAGED_SIZE
// for pg in [0, ceil(page_kernel_lens[pid] / PAGED_SIZE)). kv_start is
// kv_start_idx[pid] when kv_start_idx is non-null, and 0 otherwise.
//
// Launch: one block per request (blockIdx.x == request index). The block's
// threads stride over that request's pages, so any blockDim.x >= 1 is valid.
// Offsets are 64-bit because req_pool_index * req_to_token_stride can exceed
// INT32_MAX for large request pools.
template<int PAGED_SIZE>
__global__ void dcu_create_flashmla_kv_indices_kernel(
    const int32_t* __restrict__ req_to_token,
    const int32_t* __restrict__ req_pool_indices,
    const int32_t* __restrict__ page_kernel_lens,
    const int32_t* __restrict__ kv_start_idx,
    int32_t* __restrict__ kv_indices,
    int64_t req_to_token_stride,
    int64_t kv_indices_stride)
{
    const int64_t pid = blockIdx.x;  // batch index

    const int64_t kv_start = (kv_start_idx != nullptr) ? kv_start_idx[pid] : 0;
    const int64_t total_len = page_kernel_lens[pid];
    const int64_t num_pages = (total_len + PAGED_SIZE - 1) / PAGED_SIZE;

    const int32_t* token_row =
        req_to_token + static_cast<int64_t>(req_pool_indices[pid]) * req_to_token_stride + kv_start;
    int32_t* out_row = kv_indices + pid * kv_indices_stride;

    for (int64_t pg = threadIdx.x; pg < num_pages; pg += blockDim.x) {
        // First token id of page pg -> page index.
        out_row[pg] = token_row[pg * PAGED_SIZE] / PAGED_SIZE;
    }
}

// Host launcher for dcu_create_flashmla_kv_indices_kernel on PyTorch's
// current stream; only PAGED_SIZE == 64 is instantiated. data_ptr<int32_t>()
// enforces int32 for req_to_token, req_pool_indices, page_kernel_lens,
// kv_start_idx (when given) and kv_indices. The two strides are row strides
// in elements, as supplied by the caller.
void dcu_create_flashmla_kv_indices(
    const at::Tensor& req_to_token,
    const at::Tensor& req_pool_indices,
    const at::Tensor& page_kernel_lens,
    const c10::optional<at::Tensor>& kv_start_idx,
    at::Tensor& kv_indices,
    int64_t req_to_token_stride,
    int64_t kv_indices_stride,
    int64_t PAGED_SIZE)

{
    TORCH_CHECK(req_to_token.is_cuda(), "req_to_token must be CUDA tensor");
    TORCH_CHECK(kv_indices.is_cuda(), "kv_indices must be CUDA tensor");

    const int64_t bs = req_pool_indices.size(0);
    auto stream = at::cuda::getCurrentCUDAStream();

    const int32_t* kv_start_idx_ptr = nullptr;
    if (kv_start_idx.has_value()) {
        kv_start_idx_ptr = kv_start_idx.value().data_ptr<int32_t>();
    }
    if (PAGED_SIZE != 64) {
        TORCH_CHECK(false, "Unsupported PAGED_SIZE");
    }

    const int32_t* req_to_token_ptr = req_to_token.data_ptr<int32_t>();
    const int32_t* req_pool_indices_ptr = req_pool_indices.data_ptr<int32_t>();
    const int32_t* page_kernel_lens_ptr = page_kernel_lens.data_ptr<int32_t>();
    int32_t* kv_indices_ptr = kv_indices.data_ptr<int32_t>();

    // Empty batch: nothing to write, and a zero-sized grid is an invalid launch.
    if (bs == 0) return;

    // One block per request; its threads split that request's pages.
    constexpr unsigned int kThreadsPerBlock = 64;
    const dim3 grid(static_cast<unsigned int>(bs));
    const dim3 block(kThreadsPerBlock);

    dcu_create_flashmla_kv_indices_kernel<64><<<grid, block, 0, stream>>>(
        req_to_token_ptr,
        req_pool_indices_ptr,
        page_kernel_lens_ptr,
        kv_start_idx_ptr,
        kv_indices_ptr,
        req_to_token_stride,
        kv_indices_stride
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}



// Copies one prefix chunk per request out of the request->token table.
//
// One thread per request (index `req` < bs). Request `req` copies
// chunk_seq_lens_ptr[req] consecutive entries of row req_pool_indices_ptr[req]
// of req_to_token_ptr (row length col_num), starting at column
// chunk_starts_ptr[req], into chunk_kv_indices_ptr starting at
// chunk_cu_seq_lens_ptr[req].
__global__ void launch_create_chunked_prefix_cache_kv_indices(
    int32_t* req_to_token_ptr,
    const int64_t* req_pool_indices_ptr,
    const int32_t* chunk_starts_ptr,
    const int32_t* chunk_seq_lens_ptr,
    const int32_t* chunk_cu_seq_lens_ptr,
    int32_t* chunk_kv_indices_ptr,
    int64_t col_num,
    int64_t bs)
{
    const int64_t req = blockIdx.x * blockDim.x + threadIdx.x;
    if (req >= bs) {
        return;
    }

    const int32_t* src =
        req_to_token_ptr + (req_pool_indices_ptr[req] * col_num + chunk_starts_ptr[req]);
    int32_t* dst = chunk_kv_indices_ptr + static_cast<int64_t>(chunk_cu_seq_lens_ptr[req]);
    const int32_t chunk_len = chunk_seq_lens_ptr[req];

    #pragma unroll(32)
    for (int32_t t = 0; t < chunk_len; ++t) {
        dst[t] = src[t];
    }
}


// Host launcher for launch_create_chunked_prefix_cache_kv_indices: one thread
// per request, 64 threads per block, enqueued on PyTorch's current stream.
//
// Tensor storage is reinterpreted with an unchecked static_cast: req_pool_indices
// must be int64; req_to_token, chunk_starts, chunk_seq_lens, chunk_cu_seq_lens
// and chunk_kv_indices must be int32. `col_num` is the row length of req_to_token.
void dcu_create_chunked_prefix_cache_kv_indices(
    at::Tensor req_to_token_ptr,
    const at::Tensor req_pool_indices_ptr,
    const at::Tensor chunk_starts_ptr,
    const at::Tensor chunk_seq_lens_ptr,
    const at::Tensor chunk_cu_seq_lens_ptr,
    at::Tensor chunk_kv_indices_ptr,
    int64_t col_num,
    int64_t bs) {

    // Empty batch: nothing to copy, and a zero-sized grid would be rejected
    // as an invalid launch configuration.
    if (bs == 0) return;

    int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
    const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
    const int32_t* chunk_starts_ptr1 = static_cast<const int32_t*>(chunk_starts_ptr.data_ptr());
    const int32_t* chunk_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_seq_lens_ptr.data_ptr());
    const int32_t* chunk_cu_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_cu_seq_lens_ptr.data_ptr());
    int32_t* chunk_kv_indices_ptr1 = static_cast<int32_t*>(chunk_kv_indices_ptr.data_ptr());

    const int64_t block_size = 64;
    const int64_t grid_size = (bs + block_size - 1) / block_size;
    cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();

    launch_create_chunked_prefix_cache_kv_indices<<<grid_size, block_size, 0, torch_current_stream>>>(
        req_to_token_ptr1, req_pool_indices_ptr1, chunk_starts_ptr1, chunk_seq_lens_ptr1,
        chunk_cu_seq_lens_ptr1, chunk_kv_indices_ptr1, col_num, bs);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}