transfer.cu 23.9 KB
Newer Older
1
2
3
4
5
6
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/irange.h>

#include <cstdint>

7
#ifndef USE_ROCM
maxiao1's avatar
maxiao1 committed
8
#define WARP_SIZE 64
9
#include "pytorch_extension_utils.h"
10
11
12
13
#else
#include "pytorch_extension_utils_rocm.h"
#include "utils.h"  // WARP_SIZE
#endif
14
15
16

// Copies one item of `item_size_bytes` bytes from src_addr to dst_addr using all
// lanes of a (logical) warp. The item is treated as an array of 64-bit words and
// lane `lane_id` copies words lane_id, lane_id + WARP_SIZE, lane_id + 2 * WARP_SIZE, ...
// Bytes past the last whole 8-byte word are not copied (the host launcher rejects
// item sizes that are not a multiple of 8).
// Non-ROCm path: `ld.global.nc` (non-coherent load) + `st.global.cg` (cache-global
// store) inline PTX. ROCm path: compiler nontemporal load/store builtins.
__device__ __forceinline__ void
transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) {
  const uint64_t* __restrict__ in_words = static_cast<const uint64_t*>(src_addr);
  uint64_t* __restrict__ out_words = static_cast<uint64_t*>(dst_addr);
  const int num_words = item_size_bytes / sizeof(uint64_t);

#pragma unroll
  for (int w = lane_id; w < num_words; w += WARP_SIZE) {
    const uint64_t* from = in_words + w;
    uint64_t* to = out_words + w;
#ifndef USE_ROCM
    uint64_t word;
    asm volatile("ld.global.nc.b64 %0,[%1];" : "=l"(word) : "l"(from) : "memory");
    asm volatile("st.global.cg.b64 [%0],%1;" ::"l"(to), "l"(word) : "memory");
#else
    uint64_t word = __builtin_nontemporal_load(from);
    __builtin_nontemporal_store(word, to);
#endif
  }
}

35
36
37
38
39
40
41
42
// Layer-first addressing: returns `base` advanced by
//   layer_id * layer_dim + page_id * item_size_bytes
// elements of T (bytes for the char instantiations used in this file).
// The layer-table argument is ignored; it is present so that every offset function
// shares the signature the transfer kernel calls SrcOffsetFn / DstOffsetFn with.
template <typename T>
__device__ __forceinline__ T* get_global_offset_lf(
    T* base,
    const uintptr_t* __restrict__ /*unused*/,
    int64_t layer_id,
    int64_t layer_dim,
    int64_t page_id,
    int64_t item_size_bytes) {
  const int64_t layer_offset = layer_id * layer_dim;
  const int64_t page_offset = page_id * item_size_bytes;
  return base + (layer_offset + page_offset);
}

47
48
49
50
51
52
53
54
// Page-first addressing: returns `base` advanced by
//   page_id * page_dim + layer_id * item_size_bytes
// elements of T (bytes for the char instantiations used in this file), i.e. all
// layers of one page are adjacent and the page stride is `page_dim`.
// The layer-table argument is ignored; it keeps the common offset-function signature.
template <typename T>
__device__ __forceinline__ T* get_global_offset_pf(
    T* base,
    const uintptr_t* __restrict__ /*unused*/,
    int64_t layer_id,
    int64_t page_dim,
    int64_t page_id,
    int64_t item_size_bytes) {
  const int64_t page_offset = page_id * page_dim;
  const int64_t layer_offset = layer_id * item_size_bytes;
  return base + (page_offset + layer_offset);
}

// Layer-first addressing through a per-layer base-address table, for layers that are
// not contiguous in memory: layer_base_tbl[layer_id] holds the address of that layer's
// buffer, and the item for page_id starts page_id * item_size_bytes elements into it.
// The `base` pointer and the layout-dimension argument are ignored.
template <typename T>
__device__ __forceinline__ T* get_global_offset_lf_tbl(
    T* /*unused*/,
    const uintptr_t* __restrict__ layer_base_tbl,
    int64_t layer_id,
    int64_t /*unused*/,
    int64_t page_id,
    int64_t item_size_bytes) {
  T* const layer_base = reinterpret_cast<T*>(layer_base_tbl[layer_id]);
  return layer_base + page_id * item_size_bytes;
}

// Copies `num_items` items for every layer in [start_layer_id, start_layer_id + num_layers_to_process).
//
// Layout: 1-D grid of 1-D blocks. Consecutive groups of WARP_SIZE global thread ids form
// one logical warp. Warp w handles item ids [w * items_per_warp, (w + 1) * items_per_warp),
// stopping at num_items, and its lanes cooperate on each item via transfer_item_warp.
// Item i is read from page src_indices[i] and written to page dst_indices[i].
// SrcOffsetFn / DstOffsetFn map (layer, page) to an address using the base pointer, the
// optional per-layer table and the layout dimension. The K buffers are always copied;
// the V buffers are copied too unless IsMLA.
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
__global__ void transfer_kernel_impl(
    const void* __restrict__ src_k,
    void* __restrict__ dst_k,
    const void* __restrict__ src_v,
    void* __restrict__ dst_v,
    const int64_t* __restrict__ src_indices,
    const int64_t* __restrict__ dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t num_items,
    int64_t items_per_warp,
    int64_t item_size_bytes,
    int64_t src_layout_dim,
    int64_t dst_layout_dim,
    const uintptr_t* __restrict__ src_k_layer_tbl,
    const uintptr_t* __restrict__ dst_k_layer_tbl,
    const uintptr_t* __restrict__ src_v_layer_tbl,
    const uintptr_t* __restrict__ dst_v_layer_tbl) {
  const int32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t lane_id = global_tid % WARP_SIZE;
  const int32_t warp_id = global_tid / WARP_SIZE;
  const int64_t end_layer_id = start_layer_id + num_layers_to_process;

  const char* src_k_base = static_cast<const char*>(src_k);
  char* dst_k_base = static_cast<char*>(dst_k);

  for (int i = 0; i < items_per_warp; ++i) {
    const int64_t item_id = warp_id * items_per_warp + i;
    if (item_id >= num_items) {
      break;  // this warp's range runs past the last item
    }
    const int64_t src_page = src_indices[item_id];
    const int64_t dst_page = dst_indices[item_id];

    for (int64_t layer = start_layer_id; layer < end_layer_id; ++layer) {
      transfer_item_warp(
          lane_id,
          SrcOffsetFn(src_k_base, src_k_layer_tbl, layer, src_layout_dim, src_page, item_size_bytes),
          DstOffsetFn(dst_k_base, dst_k_layer_tbl, layer, dst_layout_dim, dst_page, item_size_bytes),
          item_size_bytes);

      if constexpr (!IsMLA) {
        transfer_item_warp(
            lane_id,
            SrcOffsetFn(
                static_cast<const char*>(src_v), src_v_layer_tbl, layer, src_layout_dim, src_page, item_size_bytes),
            DstOffsetFn(static_cast<char*>(dst_v), dst_v_layer_tbl, layer, dst_layout_dim, dst_page, item_size_bytes),
            item_size_bytes);
      }
    }
  }
}

// Host-side launcher shared by the transfer_kv_* entry points in this file.
// It validates the index tensors, sizes the grid and launches
// transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA> on the current CUDA stream.
//
// Work split: num_items = src_indices.numel() items are spread over at most
// block_quota blocks of num_warps_per_block warps. Each warp copies items_per_warp
// consecutive items, for every requested layer.
//
// Undefined tensors reach the kernel as null pointers, and the V pointers/tables are
// always null when IsMLA. An empty index list is a no-op.
// Throws (TORCH_CHECK) on:
//   - non-CUDA or non-int64 indices;
//   - mismatched index lengths;
//   - item_size not a multiple of 8;
//   - non-positive block_quota or num_warps_per_block.
template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
void transfer_kv_launcher(
    const at::Tensor& src_k,
    at::Tensor& dst_k,
    const at::Tensor& src_v,
    at::Tensor& dst_v,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t start_layer_id,
    int64_t num_layers_to_process,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t dst_layout_dim,
    const at::Tensor& src_k_layers,
    const at::Tensor& dst_k_layers,
    const at::Tensor& src_v_layers,
    const at::Tensor& dst_v_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
  TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
  TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
  TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long");
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(item_size % 8 == 0, "Item byte size must be divisible by 8");
  // Both values are divisors in the grid-size computation below.
  TORCH_CHECK(block_quota > 0, "block_quota must be positive");
  TORCH_CHECK(num_warps_per_block > 0, "num_warps_per_block must be positive");

  const int64_t num_items = src_indices.numel();
  if (num_items == 0) {
    // Nothing to copy. Without this return, items_per_warp would be 0 and the
    // num_blocks computation would divide by zero.
    return;
  }

  auto div_up = [](int64_t x, int64_t y) { return (x + y - 1) / y; };

  const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
  const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
  dim3 grid_dim(num_blocks, 1, 1);
  const int32_t threads_per_block = num_warps_per_block * WARP_SIZE;

  const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr;
  void* dst_k_ptr = dst_k.defined() ? dst_k.data_ptr() : nullptr;
  const void* src_v_ptr = IsMLA || !src_v.defined() ? nullptr : src_v.data_ptr();
  void* dst_v_ptr = IsMLA || !dst_v.defined() ? nullptr : dst_v.data_ptr();
  const uintptr_t* src_k_tbl_ptr = src_k_layers.defined() ? src_k_layers.data_ptr<uintptr_t>() : nullptr;
  const uintptr_t* dst_k_tbl_ptr = dst_k_layers.defined() ? dst_k_layers.data_ptr<uintptr_t>() : nullptr;
  const uintptr_t* src_v_tbl_ptr = IsMLA || !src_v_layers.defined() ? nullptr : src_v_layers.data_ptr<uintptr_t>();
  const uintptr_t* dst_v_tbl_ptr = IsMLA || !dst_v_layers.defined() ? nullptr : dst_v_layers.data_ptr<uintptr_t>();

  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
      src_k_ptr,
      dst_k_ptr,
      src_v_ptr,
      dst_v_ptr,
      src_indices.data_ptr<int64_t>(),
      dst_indices.data_ptr<int64_t>(),
      start_layer_id,
      num_layers_to_process,
      num_items,
      items_per_warp,
      item_size,
      src_layout_dim,
      dst_layout_dim,
      src_k_tbl_ptr,
      dst_k_tbl_ptr,
      src_v_tbl_ptr,
      dst_v_tbl_ptr);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Copies one layer of separate K and V caches between two layer-first buffers:
// item i goes from page src_indices[i] of src_k/src_v to page dst_indices[i] of
// dst_k/dst_v. Layer id and layout dims are 0, so each item sits at
// page * item_size inside the given tensors.
void transfer_kv_per_layer(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor no_layer_table;  // layer-first addressing does not use base-address tables
  constexpr int64_t kFirstLayer = 0;
  constexpr int64_t kSingleLayer = 1;
  constexpr int64_t kNoLayoutDim = 0;
  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, /*IsMLA=*/false>(
      src_k,
      dst_k,
      src_v,
      dst_v,
      src_indices,
      dst_indices,
      kFirstLayer,
      kSingleLayer,
      item_size,
      kNoLayoutDim,
      kNoLayoutDim,
      no_layer_table,
      no_layer_table,
      no_layer_table,
      no_layer_table,
      block_quota,
      num_warps_per_block);
}

216
// Copies layer `layer_id` of separate K and V caches from a page-first source to a
// layer-first destination.
// - Source item for page p: p * src_layout_dim + layer_id * item_size.
// - Destination: addressed by page only, because its layout dim is passed as 0 (so
//   layer_id * 0 drops out). dst_k/dst_v presumably hold just this layer; TODO confirm
//   with callers.
void transfer_kv_per_layer_pf_lf(
    const at::Tensor src_k,
    at::Tensor dst_k,
    const at::Tensor src_v,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t layer_id,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor no_layer_table;  // neither side uses a base-address table
  constexpr int64_t kSingleLayer = 1;
  constexpr int64_t kDstLayoutDim = 0;
  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, /*IsMLA=*/false>(
      src_k,
      dst_k,
      src_v,
      dst_v,
      src_indices,
      dst_indices,
      layer_id,
      kSingleLayer,
      item_size,
      src_layout_dim,
      kDstLayoutDim,
      no_layer_table,
      no_layer_table,
      no_layer_table,
      no_layer_table,
      block_quota,
      num_warps_per_block);
}

// Copies layers 0 .. num_layers-1 of separate K and V caches whose layers are not
// contiguous.
// - Each *_layers tensor is a per-layer base-address table, read on the device as
//   uintptr_t and indexed by layer id.
// - Item i moves from page src_indices[i] to page dst_indices[i] in every layer.
// Throws if any of the four tables does not have exactly num_layers entries.
void transfer_kv_all_layer(
    const at::Tensor src_k_layers,
    const at::Tensor dst_k_layers,
    const at::Tensor src_v_layers,
    const at::Tensor dst_v_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  // The kernel indexes every table with layer ids 0..num_layers-1; a shorter table
  // would be read out of bounds on the device, so validate all four, not just src K.
  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
  TORCH_CHECK(
      num_layers == dst_k_layers.size(0), "Number of layers in destination k tensor does not match num_layers");
  TORCH_CHECK(num_layers == src_v_layers.size(0), "Number of layers in source v tensor does not match num_layers");
  TORCH_CHECK(
      num_layers == dst_v_layers.size(0), "Number of layers in destination v tensor does not match num_layers");
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, false>(
      empty,
      empty,
      empty,
      empty,
      src_indices,
      dst_indices,
      0,
      num_layers,
      item_size,
      0,
      0,
      src_k_layers,
      dst_k_layers,
      src_v_layers,
      dst_v_layers,
      block_quota,
      num_warps_per_block);
}

// Copies layers 0 .. num_layers-1 of separate K and V caches from non-contiguous
// layer-first sources into page-first destinations.
// - Sources: per-layer base-address tables src_k_layers / src_v_layers.
// - Destinations: dst_k / dst_v, addressed as page * dst_layout_dim + layer * item_size.
// Throws if either source table does not have exactly num_layers entries.
void transfer_kv_all_layer_lf_pf(
    const at::Tensor src_k_layers,
    at::Tensor dst_k,
    const at::Tensor src_v_layers,
    at::Tensor dst_v,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t dst_layout_dim,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  // Both source tables are indexed with layer ids 0..num_layers-1 on the device.
  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
  TORCH_CHECK(num_layers == src_v_layers.size(0), "Number of layers in source v tensor does not match num_layers");
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, false>(
      empty,
      dst_k,
      empty,
      dst_v,
      src_indices,
      dst_indices,
      0,
      num_layers,
      item_size,
      0,
      dst_layout_dim,
      src_k_layers,
      empty,
      src_v_layers,
      empty,
      block_quota,
      num_warps_per_block);
}

// Single-buffer (IsMLA) variant of transfer_kv_per_layer: only one cache tensor is
// copied between two layer-first buffers, item i going from page src_indices[i] of
// `src` to page dst_indices[i] of `dst`.
void transfer_kv_per_layer_mla(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor none;  // stands in for the unused V buffers and layer tables
  constexpr int64_t kFirstLayer = 0;
  constexpr int64_t kSingleLayer = 1;
  constexpr int64_t kNoLayoutDim = 0;
  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, /*IsMLA=*/true>(
      src,
      dst,
      none,
      none,
      src_indices,
      dst_indices,
      kFirstLayer,
      kSingleLayer,
      item_size,
      kNoLayoutDim,
      kNoLayoutDim,
      none,
      none,
      none,
      none,
      block_quota,
      num_warps_per_block);
}

345
// Single-buffer (IsMLA) copy of layer `layer_id` from a page-first source to a
// layer-first destination.
// - Source item for page p: p * src_layout_dim + layer_id * item_size.
// - Destination: addressed by page only, because its layout dim is 0.
void transfer_kv_per_layer_mla_pf_lf(
    const at::Tensor src,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t layer_id,
    int64_t item_size,
    int64_t src_layout_dim,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  at::Tensor none;  // stands in for the unused V buffers and layer tables
  constexpr int64_t kSingleLayer = 1;
  constexpr int64_t kDstLayoutDim = 0;
  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, /*IsMLA=*/true>(
      src,
      dst,
      none,
      none,
      src_indices,
      dst_indices,
      layer_id,
      kSingleLayer,
      item_size,
      src_layout_dim,
      kDstLayoutDim,
      none,
      none,
      none,
      none,
      block_quota,
      num_warps_per_block);
}

// Single-buffer (IsMLA) copy of layers 0 .. num_layers-1 between two sets of
// non-contiguous layers.
// - src_layers / dst_layers are per-layer base-address tables (uintptr_t on device).
// - Item i moves from page src_indices[i] to page dst_indices[i] in every layer.
// Throws if either table does not have exactly num_layers entries.
void transfer_kv_all_layer_mla(
    const at::Tensor src_layers,
    const at::Tensor dst_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  // Both tables are indexed with layer ids 0..num_layers-1 on the device, so a short
  // destination table would be read out of bounds; check it as well as the source.
  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
  TORCH_CHECK(num_layers == dst_layers.size(0), "Number of layers in destination tensor does not match num_layers");
  at::Tensor empty;
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, true>(
      empty,
      empty,
      empty,
      empty,
      src_indices,
      dst_indices,
      0,
      num_layers,
      item_size,
      0,
      0,
      src_layers,
      dst_layers,
      empty,
      empty,
      block_quota,
      num_warps_per_block);
}

// Single-buffer (IsMLA) copy of layers 0 .. num_layers-1 from non-contiguous layer-first
// sources into one page-first destination.
// - Sources: per-layer base-address table src_layers.
// - Destination: `dst`, addressed as page * dst_layout_dim + layer * item_size.
// Throws if src_layers does not have exactly num_layers entries.
void transfer_kv_all_layer_mla_lf_pf(
    const at::Tensor src_layers,
    at::Tensor dst,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t item_size,
    int64_t dst_layout_dim,
    int64_t num_layers,
    int64_t block_quota,
    int64_t num_warps_per_block) {
  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");

  at::Tensor none;  // stands in for unused buffers / tables
  constexpr int64_t kFirstLayer = 0;
  constexpr int64_t kSrcLayoutDim = 0;  // ignored by table-based addressing
  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, /*IsMLA=*/true>(
      none,
      dst,
      none,
      none,
      src_indices,
      dst_indices,
      kFirstLayer,
      num_layers,
      item_size,
      kSrcLayoutDim,
      dst_layout_dim,
      src_layers,
      none,
      none,
      none,
      block_quota,
      num_warps_per_block);
}

// Copies `page_size` consecutive dim-0 rows of src_buffer, starting at row
// src_page_index, into dst_buffer starting at row dst_page_index. Uses Tensor::copy_
// with non_blocking = true. Callers in this file pass any row count as `page_size`,
// not only a literal page size.
inline void transfer_page_direct(
    const at::Tensor src_buffer,
    at::Tensor dst_buffer,
    int64_t src_page_index,
    int64_t dst_page_index,
    int64_t page_size) {
  const int64_t src_end = src_page_index + page_size;
  const int64_t dst_end = dst_page_index + page_size;
  auto dst_rows = dst_buffer.slice(/*dim=*/0, dst_page_index, dst_end);
  const auto src_rows = src_buffer.slice(/*dim=*/0, src_page_index, src_end);
  dst_rows.copy_(src_rows, /*non_blocking=*/true);
}

451
452
453
454
455
456
457
458
// Copies dim-0 rows of every layer tensor in src_layers to the same-numbered tensor in
// dst_layers: row src_indices[i] -> row dst_indices[i].
// Consecutive index pairs in which both the source and the destination index grow by
// exactly 1 are merged into one run. Each run is issued as a single slice copy per layer
// (transfer_page_direct) instead of one copy per row.
//
// The index tensors are brought to the host with .cpu() and read as int64. This
// synchronizes when they live on the device. page_size is only used for validation:
// the index count must be a multiple of it.
void transfer_kv_direct(
    const std::vector<at::Tensor>& src_layers,
    std::vector<at::Tensor> dst_layers,
    const at::Tensor src_indices,
    const at::Tensor dst_indices,
    int64_t page_size) {
  TORCH_CHECK(
      src_layers.size() == dst_layers.size(), "Source and destination layers must have the same number of layers");
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(page_size > 0, "Page size must be positive");
  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");

  // The raw pointer walk below assumes unit stride. .cpu() returns the tensor unchanged
  // when it is already on the host, so force a contiguous layout explicitly.
  const at::Tensor src_indices_cpu = src_indices.cpu().contiguous();
  const at::Tensor dst_indices_cpu = dst_indices.cpu().contiguous();
  const int64_t* src_idx = src_indices_cpu.data_ptr<int64_t>();
  const int64_t* dst_idx = dst_indices_cpu.data_ptr<int64_t>();

  const int64_t num_indices = src_indices_cpu.numel();
  const int64_t num_layers = static_cast<int64_t>(src_layers.size());

  int64_t run_start = 0;
  for (int64_t i = 0; i < num_indices; ++i) {
    const bool is_last = (i == num_indices - 1);
    // Keep extending the current run while both index streams advance by exactly one.
    if (!is_last && src_idx[i + 1] - src_idx[i] == 1 && dst_idx[i + 1] - dst_idx[i] == 1) {
      continue;
    }
    const int64_t run_end = i + 1;
    const int64_t run_length = run_end - run_start;
    for (int64_t layer = 0; layer < num_layers; ++layer) {
      transfer_page_direct(src_layers[layer], dst_layers[layer], src_idx[run_start], dst_idx[run_start], run_length);
    }
    run_start = run_end;
  }
}
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573

// Copy_-based (no custom kernel) transfer between per-layer tensors and page-first
// buffers, one page at a time.
// - Per-layer side: one tensor per layer; K layers first, then V layers, unless MLA.
// - Page-first side: buffer[0] for K, buffer[1] for V unless MLA. It is addressed as
//   buffer.select(0, page).select(0, layer).
// - IsLf2Pf = true: per-layer src_ptrs -> page-first dst_ptrs. MLA when dst_ptrs has one tensor.
// - IsLf2Pf = false: page-first src_ptrs -> per-layer dst_ptrs. MLA when src_ptrs has one tensor.
// Page-first layers are start_layer_id + j, for j over the per-layer tensors.
//
// Indices come in groups of page_size. Only the first index of each group is read:
// on the per-layer side it is the starting row; on the page-first side it is divided
// by page_size to obtain the page number. The groups are presumably whole, aligned
// pages; TODO confirm with callers.
template <bool IsLf2Pf>
inline void transfer_kv_page_first_direct_impl(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t start_layer_id,
    int64_t page_size) {
  TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
  TORCH_CHECK(page_size > 0, "Page size must be positive");
  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");

  // Read the indices once as flat, contiguous int64 host arrays. The previous
  // Tensor::operator[] + item() per page allocated two temporary tensors per page.
  const at::Tensor src_indices_cpu = src_indices.cpu().to(at::kLong).reshape({-1}).contiguous();
  const at::Tensor dst_indices_cpu = dst_indices.cpu().to(at::kLong).reshape({-1}).contiguous();
  const int64_t* src_idx = src_indices_cpu.data_ptr<int64_t>();
  const int64_t* dst_idx = dst_indices_cpu.data_ptr<int64_t>();
  const int64_t num_pages = src_indices_cpu.numel() / page_size;

  if constexpr (IsLf2Pf) {
    const bool is_mla = dst_ptrs.size() == 1;
    const int64_t num_layers = is_mla ? src_ptrs.size() : src_ptrs.size() / 2;

    for (const auto i : c10::irange(num_pages)) {
      const int64_t s_index = src_idx[i * page_size];
      const int64_t d_index = dst_idx[i * page_size] / page_size;
      for (int64_t j = 0; j < num_layers; ++j) {
        transfer_page_direct(
            src_ptrs[j], dst_ptrs[0].select(0, d_index).select(0, start_layer_id + j), s_index, 0, page_size);
        if (!is_mla) {
          transfer_page_direct(
              src_ptrs[j + num_layers],
              dst_ptrs[1].select(0, d_index).select(0, start_layer_id + j),
              s_index,
              0,
              page_size);
        }
      }
    }
  } else {
    const bool is_mla = src_ptrs.size() == 1;
    const int64_t num_layers = is_mla ? dst_ptrs.size() : dst_ptrs.size() / 2;

    for (const auto i : c10::irange(num_pages)) {
      const int64_t s_index = src_idx[i * page_size] / page_size;
      const int64_t d_index = dst_idx[i * page_size];
      for (int64_t j = 0; j < num_layers; ++j) {
        transfer_page_direct(
            src_ptrs[0].select(0, s_index).select(0, start_layer_id + j), dst_ptrs[j], 0, d_index, page_size);
        if (!is_mla) {
          transfer_page_direct(
              src_ptrs[1].select(0, s_index).select(0, start_layer_id + j),
              dst_ptrs[j + num_layers],
              0,
              d_index,
              page_size);
        }
      }
    }
  }
}

// Page-first -> per-layer direct (copy_-based) transfer. Page-first layers
// layer_id, layer_id + 1, ... map to the per-layer tensors in dst_ptrs.
void transfer_kv_per_layer_direct_pf_lf(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t layer_id,
    int64_t page_size) {
  constexpr bool kLayerFirstToPageFirst = false;
  transfer_kv_page_first_direct_impl<kLayerFirstToPageFirst>(
      src_ptrs, dst_ptrs, src_indices, dst_indices, /*start_layer_id=*/layer_id, page_size);
}

// Per-layer -> page-first direct (copy_-based) transfer of all layers, starting at
// page-first layer 0.
void transfer_kv_all_layer_direct_lf_pf(
    const std::vector<at::Tensor>& src_ptrs,
    std::vector<at::Tensor> dst_ptrs,
    const at::Tensor& src_indices,
    const at::Tensor& dst_indices,
    int64_t page_size) {
  constexpr bool kLayerFirstToPageFirst = true;
  constexpr int64_t kFirstLayer = 0;
  transfer_kv_page_first_direct_impl<kLayerFirstToPageFirst>(
      src_ptrs, dst_ptrs, src_indices, dst_indices, kFirstLayer, page_size);
}
liucong's avatar
liucong committed
574
575
576
577
578

// Ceiling division computed as (a + b - 1) / b with C++ truncating division.
// Exact for a >= 0 and b > 0. For negative a it is not a true ceiling:
// ceil_div(-5, 4) == 0, whereas ceil(-1.25) == -1.
__device__ int64_t ceil_div(int64_t a, int64_t b) {
  const int64_t biased = a + b - 1;
  return biased / b;
}

liucong's avatar
liucong committed
579
580
581
582
// Device-side minimum of two int64_t values.
__device__ int64_t safe_min(int64_t a, int64_t b) {
  if (b <= a) {
    return b;
  }
  return a;
}

liucong's avatar
liucong committed
583
584
585
586
587
588
// Picks the KV-cache slot for the single new token of each request in a decode step.
// Request pid grows from seq_lens_ptr[pid] - 1 to seq_lens_ptr[pid] tokens.
// - If that needs no new page, the slot is last_loc_ptr[pid] + 1.
// - Otherwise it is the first slot of a page taken from free_page_ptr. Pages are handed
//   out in request order: request pid takes free_page_ptr[k], where k is the number of
//   new pages needed by requests 0..pid-1.
//
// Launch layout: one thread per request (pid = global thread id). Threads with
// pid >= bs_upper return immediately. Every other thread reads seq_lens_ptr[0..pid] and
// last_loc_ptr[pid] and writes out_indices[pid], so bs_upper must not exceed the length
// of those arrays. Each thread recomputes its prefix sum itself, i.e. O(pid) work.
__global__ void launch_alloc_decode_kernel(
    const int64_t* seq_lens_ptr,
    const int32_t* last_loc_ptr,
    const int64_t* free_page_ptr,
    int64_t* out_indices,
    int64_t bs_upper,
    int64_t page_size) {
  const int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs_upper) {
    return;
  }

  const int64_t seq_len = seq_lens_ptr[pid];
  const int64_t own_new_pages = ceil_div(seq_len, page_size) - ceil_div(seq_len - 1, page_size);

  // Exclusive prefix sum: new pages claimed by requests 0..pid-1, i.e. the index of
  // this request's page in free_page_ptr.
  int64_t new_page_start_loc = 0;
  for (int64_t r = 0; r < pid; ++r) {
    const int64_t r_seq_len = seq_lens_ptr[r];
    new_page_start_loc += ceil_div(r_seq_len, page_size) - ceil_div(r_seq_len - 1, page_size);
  }

  if (own_new_pages == 0) {
    // The new token still fits in the current last page: take the next slot.
    const int32_t last_loc = last_loc_ptr[pid];
    out_indices[pid] = last_loc + 1;
  } else {
    out_indices[pid] = free_page_ptr[new_page_start_loc] * page_size;
  }
}

liucong's avatar
liucong committed
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
// Assigns KV-cache slots for an extend step. Request pid grows from pre_lens_ptr[pid]
// to seq_lens_ptr[pid] tokens. The slot of every new token is written to out_indices,
// packed request after request.
//
// Launch layout: one thread per request (pid = global thread id). Threads with
// pid >= bs_upper return immediately, so bs_upper must not exceed the length of
// pre_lens_ptr / seq_lens_ptr / last_loc_ptr. Each thread recomputes exclusive prefix
// sums over requests 0..pid-1 (O(pid) work) to find:
// - where its output slots begin;
// - which entry of free_page_ptr is its first fresh page.
//
// New tokens are written in up to three parts:
//   1. the rest of the page that already holds the prefix, continuing from
//      last_loc_ptr[pid] + 1;
//   2. whole fresh pages taken from free_page_ptr;
//   3. the leading slots of one final, partially used fresh page.
// max_num_extend_tokens is not used by this implementation.
__global__ void launch_alloc_extend_kernel(
    const int64_t* pre_lens_ptr,
    const int64_t* seq_lens_ptr,
    const int64_t* last_loc_ptr,
    const int64_t* free_page_ptr,
    int64_t* out_indices,
    int64_t bs_upper,
    int64_t page_size,
    int64_t max_num_extend_tokens) {
  const int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs_upper) {
    return;
  }

  const int64_t seq_len = seq_lens_ptr[pid];
  const int64_t pre_len = pre_lens_ptr[pid];

  // Exclusive prefix sums over the requests before this one.
  int64_t output_start_loc = 0;    // tokens added by requests 0..pid-1
  int64_t new_page_start_loc = 0;  // fresh pages taken by requests 0..pid-1
  for (int64_t r = 0; r < pid; ++r) {
    const int64_t r_seq_len = seq_lens_ptr[r];
    const int64_t r_pre_len = pre_lens_ptr[r];
    output_start_loc += r_seq_len - r_pre_len;
    new_page_start_loc += ceil_div(r_seq_len, page_size) - ceil_div(r_pre_len, page_size);
  }
  const int64_t own_new_pages = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
  const int64_t pre_len_page_end = ceil_div(pre_len, page_size) * page_size;

  int64_t* out = out_indices + output_start_loc;

  // Part 1: fill the remainder of the page that already holds the prefix.
  const int64_t last_loc = last_loc_ptr[pid];
  const int64_t num_part1 = safe_min(seq_len, pre_len_page_end) - pre_len;
  for (int64_t k = 0; k < num_part1; ++k) {
    out[k] = last_loc + 1 + k;
  }
  if (pre_len + num_part1 == seq_len) {
    return;
  }
  out += num_part1;

  // Part 2: whole fresh pages. num_part2 is a difference of two multiples of
  // page_size, so walk it page by page and read each free-page entry once.
  const int64_t num_part2 = (seq_len / page_size) * page_size - pre_len_page_end;
  const int64_t num_full_pages = num_part2 / page_size;
  for (int64_t p = 0; p < num_full_pages; ++p) {
    const int64_t page_base = free_page_ptr[new_page_start_loc + p] * page_size;
    for (int64_t k = 0; k < page_size; ++k) {
      out[p * page_size + k] = page_base + k;
    }
  }
  if (pre_len + num_part1 + num_part2 == seq_len) {
    return;
  }
  out += num_part2;

  // Part 3: leading slots of this request's last fresh page, which is only partly used.
  const int64_t num_part3 = seq_len - (seq_len / page_size) * page_size;
  const int64_t tail_page_base = free_page_ptr[new_page_start_loc + own_new_pages - 1] * page_size;
  for (int64_t k = 0; k < num_part3 && k < page_size; ++k) {
    out[k] = tail_page_base + k;
  }
}

liucong's avatar
liucong committed
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
// Host launcher for launch_alloc_decode_kernel. It runs one thread per request, 64
// threads per block, on the current CUDA stream.
//
// The tensors are read through raw pointers, so their dtypes must match the kernel:
// seq_lens / free_page / out_indices int64, last_loc int32. seq_lens, last_loc and
// out_indices must hold at least `bs` entries.
//
// `bs` (the number of requests) is passed as the kernel's per-thread bound. The grid
// is sized from `bs`, but bs_upper may be larger: bounding the kernel by bs_upper let
// threads in [bs, bs_upper) read and write past the end of the per-request arrays.
// bs_upper is accepted only for interface compatibility. bs == 0 is a no-op, since a
// zero-block launch is an invalid configuration.
void dcu_alloc_decode_kernel(
  const at::Tensor seq_lens_ptr,
  const at::Tensor last_loc_ptr,
  const at::Tensor free_page_ptr,
  at::Tensor out_indices,
  int64_t bs,
  int64_t bs_upper,
  int64_t page_size) {
  TORCH_CHECK(page_size > 0, "page_size must be positive");
  TORCH_CHECK(bs >= 0, "bs must be non-negative");
  TORCH_CHECK(seq_lens_ptr.scalar_type() == at::kLong, "seq_lens must be int64");
  TORCH_CHECK(last_loc_ptr.scalar_type() == at::kInt, "last_loc must be int32");
  TORCH_CHECK(free_page_ptr.scalar_type() == at::kLong, "free_pages must be int64");
  TORCH_CHECK(out_indices.scalar_type() == at::kLong, "out_indices must be int64");
  TORCH_CHECK(
      seq_lens_ptr.numel() >= bs && last_loc_ptr.numel() >= bs && out_indices.numel() >= bs,
      "seq_lens, last_loc and out_indices must each hold at least bs entries");
  static_cast<void>(bs_upper);
  if (bs == 0) {
    return;
  }

  const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
  const int32_t* last_loc_ptr1 = static_cast<const int32_t*>(last_loc_ptr.data_ptr());
  const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
  int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());

  const int64_t block_size = 64;
  const int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
      seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, /*bs_upper=*/bs, page_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
liucong's avatar
liucong committed
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741

// Host launcher for launch_alloc_extend_kernel. It runs one thread per request, 64
// threads per block, on the current CUDA stream.
//
// The tensors are read through raw pointers, so all of them must be int64
// (pre_lens, seq_lens, last_loc, free_page, out_indices). pre_lens, seq_lens and
// last_loc must hold at least `bs` entries.
//
// `bs` (the number of requests) is passed as the kernel's per-thread bound. The grid
// is sized from `bs`, but bs_upper may be larger: bounding the kernel by bs_upper let
// threads in [bs, bs_upper) read past the per-request arrays and write out_indices at
// garbage offsets. bs_upper is accepted only for interface compatibility. bs == 0 is a
// no-op, since a zero-block launch is an invalid configuration.
void dcu_alloc_extend_kernel(
    const at::Tensor pre_lens_ptr,
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t bs_upper,
    int64_t page_size,
    int64_t max_num_extend_tokens) {
  TORCH_CHECK(page_size > 0, "page_size must be positive");
  TORCH_CHECK(bs >= 0, "bs must be non-negative");
  TORCH_CHECK(pre_lens_ptr.scalar_type() == at::kLong, "pre_lens must be int64");
  TORCH_CHECK(seq_lens_ptr.scalar_type() == at::kLong, "seq_lens must be int64");
  TORCH_CHECK(last_loc_ptr.scalar_type() == at::kLong, "last_loc must be int64");
  TORCH_CHECK(free_page_ptr.scalar_type() == at::kLong, "free_pages must be int64");
  TORCH_CHECK(out_indices.scalar_type() == at::kLong, "out_indices must be int64");
  TORCH_CHECK(
      pre_lens_ptr.numel() >= bs && seq_lens_ptr.numel() >= bs && last_loc_ptr.numel() >= bs,
      "pre_lens, seq_lens and last_loc must each hold at least bs entries");
  static_cast<void>(bs_upper);
  if (bs == 0) {
    return;
  }

  const int64_t* pre_lens_ptr1 = static_cast<const int64_t*>(pre_lens_ptr.data_ptr());
  const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
  const int64_t* last_loc_ptr1 = static_cast<const int64_t*>(last_loc_ptr.data_ptr());
  const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
  int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());

  const int64_t block_size = 64;
  const int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
      pre_lens_ptr1,
      seq_lens_ptr1,
      last_loc_ptr1,
      free_page_ptr1,
      out_indices1,
      /*bs_upper=*/bs,
      page_size,
      max_num_extend_tokens);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}