cache.h 2.29 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#pragma once
// SPDX-License-Identifier: MIT
 
#include <map>
#include <string>
#include <vector>

#include <torch/extension.h>

// Only the signature is declared here; the definition is not in this header.
// `src` and `dst` are taken by non-const reference; `block_mapping` is
// read-only.
// NOTE(review): judging by the name, this presumably transfers cache blocks
// from `src` into `dst` as directed by `block_mapping` (likely source /
// destination block index pairs) — confirm against the implementation.
void swap_blocks(torch::Tensor &src,
                 torch::Tensor &dst,
                 const torch::Tensor &block_mapping);

// The two std::vector parameters are passed as const references, but the
// torch::Tensor objects they hold are not themselves const. The const-ref
// vectors are required by PyTorch's C++ operator-registration code.
// NOTE(review): presumably the cache tensors inside the vectors are updated
// in place, using `block_mapping` to decide which blocks are copied where —
// confirm against the implementation.
void copy_blocks(const std::vector<torch::Tensor> &key_caches,
                 const std::vector<torch::Tensor> &value_caches,
                 const torch::Tensor &block_mapping);

void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
                       torch::Tensor &key_cache, torch::Tensor &value_cache,
                       torch::Tensor &slot_mapping,
                       const std::string &kv_cache_dtype, const double k_scale,
                       const double v_scale, const bool asm_layout);

// Only the signature is declared here; the definition is not in this header.
// Unlike reshape_and_cache, this variant receives `k_scale`/`v_scale` as
// tensors (by non-const reference) and has no `asm_layout` flag.
// NOTE(review): presumably the same cache-write operation targeting a
// "flash" cache layout — confirm against the implementation.
void reshape_and_cache_flash(torch::Tensor &key,
                             torch::Tensor &value,
                             torch::Tensor &key_cache,
                             torch::Tensor &value_cache,
                             torch::Tensor &slot_mapping,
                             const std::string &kv_cache_dtype,
                             torch::Tensor &k_scale,
                             torch::Tensor &v_scale);

void reshape_and_cache_with_pertoken_quant(torch::Tensor &key, torch::Tensor &value,
                                           torch::Tensor &key_cache, torch::Tensor &value_cache,
                                           torch::Tensor &k_dequant_scales, torch::Tensor &v_dequant_scales,
                                           torch::Tensor &slot_mapping,
                                           const bool asm_layout);

void reshape_and_cache_with_block_quant(torch::Tensor &key, torch::Tensor &value,
                                        torch::Tensor &key_cache, torch::Tensor &value_cache,
                                        torch::Tensor &k_dequant_scales, torch::Tensor &v_dequant_scales,
                                        torch::Tensor &slot_mapping,
                                        const bool asm_layout);

// Just for unittest
void convert_fp8(torch::Tensor &dst_cache, torch::Tensor &src_cache,
                 const double scale, const std::string &kv_cache_dtype);