// cache_kernels.cu — KV-cache block swap/copy, reshape-and-cache, gather and
// fp8 e5m2 conversion kernels (commits attributed to Woosuk Kwon).
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Copies whole cache blocks from `src` to `dst`.
//
// `block_mapping` maps a source block number to a destination block number.
// Block i is the slice src[i] (resp. dst[i]) along dimension 0. Each pair is
// copied with one cudaMemcpyAsync on the current CUDA stream of the GPU
// involved.
//
// Supported device combinations: GPU->GPU (same device), GPU->CPU, CPU->GPU.
// Block addresses are computed as data_ptr() + block_number * block_bytes,
// i.e. blocks are assumed to be laid out back to back — TODO(review): tensor
// contiguity is not checked here.
//
// Raises (TORCH_CHECK) on an invalid device combination, a block number out of
// range, mismatched src/dst block sizes, or a failed cudaMemcpyAsync.
void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  const torch::Device src_device = src.device();
  const torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    TORCH_CHECK(
      src_device.index() == dst_device.index(),
      "src and dst must be on the same GPU");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    TORCH_CHECK(false, "Invalid device combination");
  }

  if (block_mapping.empty()) {
    // Nothing to copy; returning here also avoids indexing src[0] on an
    // empty cache.
    return;
  }

  const int64_t src_num_blocks = src.size(0);
  const int64_t dst_num_blocks = dst.size(0);
  TORCH_CHECK(src_num_blocks > 0 && dst_num_blocks > 0,
              "src and dst must contain at least one block");

  // Copying a src block into a dst block is only safe when both blocks have
  // the same size in bytes; otherwise the memcpy would overrun one of them.
  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  TORCH_CHECK(dst.element_size() * dst[0].numel() == block_size_in_bytes,
              "src and dst blocks must have the same size in bytes");

  // Validate every pair before issuing any copy so that a bad mapping cannot
  // leave the caches partially updated or trigger an out-of-bounds memcpy.
  for (const auto& [src_block_number, dst_block_number] : block_mapping) {
    TORCH_CHECK(0 <= src_block_number && src_block_number < src_num_blocks,
                "src block number ", src_block_number,
                " is out of range [0, ", src_num_blocks, ")");
    TORCH_CHECK(0 <= dst_block_number && dst_block_number < dst_num_blocks,
                "dst block number ", dst_block_number,
                " is out of range [0, ", dst_num_blocks, ")");
  }

  char* src_ptr = static_cast<char*>(src.data_ptr());
  char* dst_ptr = static_cast<char*>(dst.data_ptr());

  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  for (const auto& [src_block_number, dst_block_number] : block_mapping) {
    const cudaError_t err = cudaMemcpyAsync(
      dst_ptr + dst_block_number * block_size_in_bytes,
      src_ptr + src_block_number * block_size_in_bytes,
      block_size_in_bytes,
      memcpy_type,
      stream);
    TORCH_CHECK(err == cudaSuccess,
                "cudaMemcpyAsync failed in swap_blocks: ", cudaGetErrorString(err));
  }
}

namespace vllm {

// Copies cache blocks to other block slots of the same per-layer cache.
//
// Launch layout: gridDim.x = number of layers, gridDim.y = number of
// (src, dst) pairs. Threads walk the block's elements with a step of
// blockDim.x, so any block size is valid.
//
//   key_cache_ptrs / value_cache_ptrs: per-layer cache base addresses,
//     stored as int64_t and reinterpreted as scalar_t*.
//   block_mapping: flattened pairs [src0, dst0, src1, dst1, ...].
//   numel_per_block: number of scalar_t elements in one cache block.
template<typename scalar_t>
__global__ void copy_blocks_kernel(
  int64_t* key_cache_ptrs,
  int64_t* value_cache_ptrs,
  const int64_t* __restrict__ block_mapping,
  const int numel_per_block) {
  // One CUDA block handles one (layer, pair) combination.
  const int layer = blockIdx.x;
  const int pair = blockIdx.y;

  scalar_t* const layer_keys = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer]);
  scalar_t* const layer_values = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer]);

  const int64_t src_block = block_mapping[2 * pair];
  const int64_t dst_block = block_mapping[2 * pair + 1];

  // Key cache: copy the whole source block into the destination block.
  const scalar_t* key_src = layer_keys + src_block * numel_per_block;
  scalar_t* key_dst = layer_keys + dst_block * numel_per_block;
  for (int elem = threadIdx.x; elem < numel_per_block; elem += blockDim.x) {
    key_dst[elem] = key_src[elem];
  }

  // Value cache: same block pair, same element count.
  const scalar_t* value_src = layer_values + src_block * numel_per_block;
  scalar_t* value_dst = layer_values + dst_block * numel_per_block;
  for (int elem = threadIdx.x; elem < numel_per_block; elem += blockDim.x) {
    value_dst[elem] = value_src[elem];
  }
}

} // namespace vllm

// Copies cache blocks inside every layer's key and value cache.
//
// `block_mapping` maps a source block number to the list of destination
// block numbers it must be copied to. The same mapping is applied to every
// layer. The mapping is flattened into (src, dst) pairs and processed by
// vllm::copy_blocks_kernel with grid = (num_layers, pairs in chunk).
//
// Assumes every cache shares the device, dtype and per-block element count
// of key_caches[0] — TODO(review): this is not validated here.
void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
  TORCH_CHECK(key_caches.size() == value_caches.size(),
              "key_caches and value_caches must have the same number of layers");
  const int num_layers = static_cast<int>(key_caches.size());
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // Flatten the mapping into [src0, dst0, src1, dst1, ...].
  std::vector<int64_t> block_mapping_vec;
  for (const auto& [src_block_number, dst_block_numbers] : block_mapping) {
    for (int64_t dst_block_number : dst_block_numbers) {
      block_mapping_vec.push_back(src_block_number);
      block_mapping_vec.push_back(dst_block_number);
    }
  }
  const int64_t num_pairs = static_cast<int64_t>(block_mapping_vec.size()) / 2;
  if (num_pairs == 0) {
    // Nothing to copy; a grid with gridDim.y == 0 is an invalid launch.
    return;
  }

  const int numel_per_block = key_caches[0][0].numel();
  if (numel_per_block == 0) {
    // Empty blocks; a zero-sized thread block is an invalid launch.
    return;
  }

  // Host-side arrays of per-layer cache addresses (std::vector instead of
  // the non-standard variable-length arrays).
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
    key_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
    value_cache_ptrs.data(), {num_layers}, torch::kInt64).to(cache_device);
  torch::Tensor block_mapping_tensor = torch::from_blob(
    block_mapping_vec.data(), {2 * num_pairs}, torch::kInt64).to(cache_device);

  // Launch the kernel. CUDA limits gridDim.y to 65535, so the pairs are
  // processed in chunks; each chunk gets a pointer offset into the mapping.
  constexpr int64_t kMaxGridDimY = 65535;
  const dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      for (int64_t first_pair = 0; first_pair < num_pairs; first_pair += kMaxGridDimY) {
        const int64_t chunk_pairs = std::min(kMaxGridDimY, num_pairs - first_pair);
        const dim3 grid(num_layers, static_cast<unsigned int>(chunk_pairs));
        vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
          key_cache_ptrs_tensor.data_ptr<int64_t>(),
          value_cache_ptrs_tensor.data_ptr<int64_t>(),
          block_mapping_tensor.data_ptr<int64_t>() + 2 * first_pair,
          numel_per_block);
      }
    }));
}

namespace vllm {

// Scatters one token's key/value vectors into the paged KV cache.
//
// Launch layout: one CUDA block per token (blockIdx.x = token index). Threads
// walk the token's num_heads * head_size elements with a step of blockDim.x.
// Tokens whose slot_mapping entry is negative are padding and are skipped.
//
// Key cache layout:   [num_blocks, num_heads, head_size/x, block_size, x]
// Value cache layout: [num_blocks, num_heads, head_size, block_size]
// When is_fp8_e5m2_kv_cache is true, each element goes through
// fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t> before being stored.
// Without ENABLE_FP8_E5M2 that branch is just assert(false).
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int64_t token = blockIdx.x;
  const int64_t slot = slot_mapping[token];
  if (slot < 0) {
    // Padding token: nothing to store.
    return;
  }
  const int64_t cache_block = slot / block_size;
  const int64_t pos_in_block = slot % block_size;

  // Element strides of the key cache (head_size / x is integer division).
  const int x_chunks = head_size / x;
  const int64_t key_x_chunk_stride = static_cast<int64_t>(block_size) * x;
  const int64_t key_head_stride = x_chunks * key_x_chunk_stride;
  const int64_t key_block_stride = num_heads * key_head_stride;
  // Element strides of the value cache.
  const int64_t value_head_stride = static_cast<int64_t>(head_size) * block_size;
  const int64_t value_block_stride = num_heads * value_head_stride;

  // Start of this token's row in the (possibly strided) key/value inputs.
  const int64_t key_row = token * key_stride;
  const int64_t value_row = token * value_stride;

  const int elems_per_token = num_heads * head_size;
  for (int e = threadIdx.x; e < elems_per_token; e += blockDim.x) {
    const int head = e / head_size;
    const int dim_in_head = e % head_size;
    const int x_chunk = dim_in_head / x;
    const int x_lane = dim_in_head % x;

    const int64_t key_dst = cache_block * key_block_stride
                            + head * key_head_stride
                            + x_chunk * key_x_chunk_stride
                            + pos_in_block * x
                            + x_lane;
    const int64_t value_dst = cache_block * value_block_stride
                              + head * value_head_stride
                              + dim_in_head * block_size
                              + pos_in_block;

    scalar_t k = key[key_row + e];
    scalar_t v = value[value_row + e];
    if constexpr (is_fp8_e5m2_kv_cache) {
#ifdef ENABLE_FP8_E5M2
      key_cache[key_dst] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(k);
      value_cache[value_dst] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(v);
#else
      assert(false);
#endif
    } else {
      key_cache[key_dst] = k;
      value_cache[value_dst] = v;
    }
  }
}

} // namespace vllm

#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE)                                \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
    reinterpret_cast<KV_T*>(key.data_ptr()),                                                       \
    reinterpret_cast<KV_T*>(value.data_ptr()),                                                     \
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                              \
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                            \
    slot_mapping.data_ptr<int64_t>(),                                                              \
    key_stride,                                                                                    \
    value_stride,                                                                                  \
    num_heads,                                                                                     \
    head_size,                                                                                     \
    block_size,                                                                                    \
    x);

// Stores each token's key/value vectors into the paged caches at the slot
// given by slot_mapping[token]. Negative slots are skipped by the kernel.
//
// kv_cache_dtype:
//   "auto"     - key_cache/value_cache use the same element type as key/value.
//   "fp8_e5m2" - key_cache/value_cache are written as uint8_t through the
//                fp8 e5m2 conversion path of reshape_and_cache_kernel.
// key/value must be float32, float16 or bfloat16. Any other key dtype or
// kv_cache_dtype string raises an error instead of being silently ignored.
void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  const bool is_fp8_e5m2 = kv_cache_dtype == "fp8_e5m2";
  TORCH_CHECK(is_fp8_e5m2 || kv_cache_dtype == "auto",
              "Unsupported data type of kv cache: ", kv_cache_dtype);
  const at::ScalarType kv_dtype = key.scalar_type();
  TORCH_CHECK(kv_dtype == at::ScalarType::Float ||
              kv_dtype == at::ScalarType::Half ||
              kv_dtype == at::ScalarType::BFloat16,
              "Unsupported data type of key/value: ", kv_dtype);

  if (num_tokens == 0 || num_heads * head_size == 0) {
    // Nothing to store; a zero-sized grid or block is an invalid launch.
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if (!is_fp8_e5m2) {
    if (kv_dtype == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, float, false);
    } else if (kv_dtype == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
    } else {  // BFloat16 (validated above)
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
    }
  } else {
    if (kv_dtype == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
    } else if (kv_dtype == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
    } else {  // BFloat16 (validated above)
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
    }
  }
}

namespace vllm {

// Element offset of one entry of a key cache laid out as
// [num_blocks, num_heads, head_size/x, block_size, x].
// Computed in int64_t so caches with more than 2^31 elements do not overflow.
static __device__ __forceinline__ int64_t paged_key_cache_offset(
  const int64_t block_idx,
  const int64_t block_offset,
  const int head_idx,
  const int head_offset,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
  const int64_t x_idx = head_offset / x;     // index along the [head_size/x] dimension
  const int64_t x_offset = head_offset % x;  // index along the innermost [x] dimension
  const int64_t head_stride = static_cast<int64_t>(head_size / x) * block_size * x;
  return block_idx * num_heads * head_stride
         + head_idx * head_stride
         + x_idx * block_size * x
         + block_offset * x
         + x_offset;
}

// Element offset of one entry of a value cache laid out as
// [num_blocks, num_heads, head_size, block_size], computed in int64_t.
static __device__ __forceinline__ int64_t paged_value_cache_offset(
  const int64_t block_idx,
  const int64_t block_offset,
  const int head_idx,
  const int head_offset,
  const int num_heads,
  const int head_size,
  const int block_size) {
  const int64_t head_stride = static_cast<int64_t>(head_size) * block_size;
  return block_idx * num_heads * head_stride
         + head_idx * head_stride
         + static_cast<int64_t>(head_offset) * block_size
         + block_offset;
}

// Grid: (num_tokens). One CUDA block copies the num_heads * head_size elements
// of one token from the paged caches into row token_idx of key/value. Rows are
// key_stride / value_stride elements apart.
// Tokens with a negative slot_mapping entry are treated as padding. They are
// skipped and their output rows are left untouched, matching the handling in
// reshape_and_cache_kernel.
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
  scalar_t* __restrict__ key,             // [num_tokens, [stride], num_heads, head_size]
  scalar_t* __restrict__ value,           // [num_tokens, [stride], num_heads, head_size]
  const scalar_t* __restrict__ key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
  const int* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,
  const int num_heads,
  const int head_size,
  const int block_size,
  const int x) {
    const int64_t token_idx = blockIdx.x;
    const int64_t slot_idx = slot_mapping[token_idx];
    if (slot_idx < 0) {
      // Padding token: there is no cache slot to read from.
      return;
    }
    const int64_t block_idx = slot_idx / block_size;
    const int64_t block_offset = slot_idx % block_size;

    const int n = num_heads * head_size;  // elements per token
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
      const int head_idx = i / head_size;
      const int head_offset = i % head_size;

      const int64_t src_key_idx = paged_key_cache_offset(
        block_idx, block_offset, head_idx, head_offset, num_heads, head_size, block_size, x);
      const int64_t src_value_idx = paged_value_cache_offset(
        block_idx, block_offset, head_idx, head_offset, num_heads, head_size, block_size);

      key[token_idx * key_stride + i] = VLLM_LDG(&key_cache[src_key_idx]);
      value[token_idx * value_stride + i] = VLLM_LDG(&value_cache[src_value_idx]);
    }
}

// Same contract and launch layout as gather_cached_kv_kernel, but each thread
// gathers unroll_factor elements spaced dim / unroll_factor apart. All of
// their loads are issued before any store (presumably to overlap load
// latency). Elements beyond unroll_factor * (dim / unroll_factor) are handled
// by a tail loop, so any num_heads * head_size is supported.
template <typename scalar_t>
__global__ void gather_cached_kv_kernel_optimized(
    scalar_t *__restrict__ key,             // [num_tokens, [stride], num_heads, head_size]
    scalar_t *__restrict__ value,           // [num_tokens, [stride], num_heads, head_size]
    const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
    const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
    const int *__restrict__ slot_mapping,   // [num_tokens]
    const int key_stride,
    const int value_stride,
    const int num_heads,
    const int head_size,
    const int block_size,
    const int x)
{
    const int64_t token_idx = blockIdx.x;
    const int64_t slot_idx = slot_mapping[token_idx];
    if (slot_idx < 0) {
        // Padding token: there is no cache slot to read from.
        return;
    }
    const int64_t block_idx = slot_idx / block_size;
    const int64_t block_offset = slot_idx % block_size;

    const int dim = num_heads * head_size;
    constexpr int unroll_factor = 4;
    const int unrolled_dim = dim / unroll_factor;

    for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
    {
        int64_t tgt_key_indices[unroll_factor];
        int64_t tgt_value_indices[unroll_factor];
        scalar_t keys_to_store[unroll_factor];
        scalar_t values_to_store[unroll_factor];

        #pragma unroll
        for (int j = 0; j < unroll_factor; ++j)
        {
            const int index = i + j * unrolled_dim;
            const int head_idx = index / head_size;
            const int head_offset = index % head_size;

            tgt_key_indices[j] = token_idx * key_stride + index;
            tgt_value_indices[j] = token_idx * value_stride + index;

            keys_to_store[j] = VLLM_LDG(&key_cache[paged_key_cache_offset(
                block_idx, block_offset, head_idx, head_offset, num_heads, head_size, block_size, x)]);
            values_to_store[j] = VLLM_LDG(&value_cache[paged_value_cache_offset(
                block_idx, block_offset, head_idx, head_offset, num_heads, head_size, block_size)]);
        }

        #pragma unroll
        for (int j = 0; j < unroll_factor; ++j)
        {
            key[tgt_key_indices[j]] = keys_to_store[j];
            value[tgt_value_indices[j]] = values_to_store[j];
        }
    }

    // Tail: the elements not covered above when dim % unroll_factor != 0.
    // These were previously dropped silently when the assert was compiled out.
    for (int index = unroll_factor * unrolled_dim + static_cast<int>(threadIdx.x);
         index < dim; index += blockDim.x)
    {
        const int head_idx = index / head_size;
        const int head_offset = index % head_size;
        key[token_idx * key_stride + index] = VLLM_LDG(&key_cache[paged_key_cache_offset(
            block_idx, block_offset, head_idx, head_offset, num_heads, head_size, block_size, x)]);
        value[token_idx * value_stride + index] = VLLM_LDG(&value_cache[paged_value_cache_offset(
            block_idx, block_offset, head_idx, head_offset, num_heads, head_size, block_size)]);
    }
}

} // namespace vllm

// Gathers each token's key/value vectors out of the paged caches. The slot for
// each token comes from `slot_mapping`, read as int32. Results are written to
// `key`/`value` using vllm::gather_cached_kv_kernel_optimized, with one CUDA
// block per token.
void gather_cached_kv(
  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  if (num_tokens == 0 || num_heads * head_size == 0) {
    // Nothing to gather; a zero-sized grid or block is an invalid launch.
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // The lambda is parenthesized (as in copy_blocks) so the commas inside the
  // <<<...>>> launch configuration never split the macro arguments.
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    ([&] {
      vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    }));
}

namespace vllm {

// Grid: (num_blocks). CUDA block b converts the block_stride elements that
// start at b * block_stride from src_cache (Tin) into dst_cache (Tout) using
// fp8_e5m2_unscaled::vec_conversion<Tout, Tin>. Threads walk those elements
// with a step of blockDim.x.
// Without ENABLE_FP8_E5M2 the loop body is assert(false), which does nothing
// when NDEBUG is defined.
template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  // int64_t loop index: the bound block_stride is int64_t, and an int index
  // would overflow before reaching a bound >= 2^31.
  for (int64_t i = threadIdx.x; i < block_stride; i += blockDim.x) {
    const int64_t idx = block_idx * block_stride + i;
#ifdef ENABLE_FP8_E5M2
    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
    assert(false);
#endif
  }
}

} // namespace vllm

#define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                 \
  vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>(  \
    reinterpret_cast<Tin*>(src_cache.data_ptr()),                        \
    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                       \
    block_stride);

// Converts a whole cache between a floating-point dtype and the uint8_t
// fp8 e5m2 representation, launching one CUDA block per index of dimension 0:
//   * src float32 / float16 / bfloat16 -> dst written as uint8_t;
//   * otherwise src is read as uint8_t and dst (float32 / float16 / bfloat16)
//     receives the converted values.
// Both tensors are addressed as block * src_cache.stride(0) + i, so presumably
// they must be contiguous with identical shapes — TODO(review): not checked.
// Raises if neither tensor has a supported floating-point dtype.
void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache)
{
  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  const at::ScalarType src_dtype = src_cache.scalar_type();
  const at::ScalarType dst_dtype = dst_cache.scalar_type();
  const auto is_float_type = [](at::ScalarType t) {
    return t == at::ScalarType::Float ||
           t == at::ScalarType::Half ||
           t == at::ScalarType::BFloat16;
  };
  TORCH_CHECK(is_float_type(src_dtype) || is_float_type(dst_dtype),
              "Unsupported data types for fp8_e5m2 conversion: ",
              src_dtype, " -> ", dst_dtype);

  if (num_blocks == 0 || block_stride == 0) {
    // Nothing to convert; a zero-sized grid or block is an invalid launch.
    return;
  }

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  // Launch on the caches' device, like the other launchers in this file.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(src_cache));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (src_dtype == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(uint8_t, float);
  } else if (src_dtype == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
  } else if (src_dtype == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
  } else if (dst_dtype == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(float, uint8_t);
  } else if (dst_dtype == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
  } else if (dst_dtype == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
  }
}