cache.h 2.97 KB
Newer Older
1
2
#pragma once

3
#include <torch/all.h>
Woosuk Kwon's avatar
Woosuk Kwon committed
4

5
6
7
#include <map>
#include <vector>

8
9
// Declaration only; the definition is not part of this header.
// `src` and `dst` are taken by mutable reference and `block_mapping` by
// const reference.
// NOTE(review): presumably transfers cache blocks from `src` to `dst` as
// directed by `block_mapping` -- confirm against the implementation.
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping);
Woosuk Kwon's avatar
Woosuk Kwon committed
10

11
12
13
14
15
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
16
                 const torch::Tensor& block_mapping);
17

18
19
20
// Declaration only; the definition is not part of this header.
// Takes a single vector of caches (`kv_caches`) rather than the separate
// key/value vectors that copy_blocks takes. The vector is a const ref, but
// the Tensors it holds are not const.
// NOTE(review): presumably the MLA counterpart of copy_blocks -- confirm
// against the implementation.
void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
                     const torch::Tensor& block_mapping);

21
22
23
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
24
25
                       const std::string& kv_cache_dtype,
                       torch::Tensor& k_scale, torch::Tensor& v_scale);
Woosuk Kwon's avatar
Woosuk Kwon committed
26

zhuwenwen's avatar
zhuwenwen committed
27
28
29
30
31
32
// Declaration only; the definition is not part of this header.
// Takes the same parameter list as reshape_and_cache.
// NOTE(review): presumably a CUDA-specific entry point for the same
// operation -- confirm against the implementation.
void reshape_and_cache_cuda(torch::Tensor& key, torch::Tensor& value,
                            torch::Tensor& key_cache,
                            torch::Tensor& value_cache,
                            torch::Tensor& slot_mapping,
                            const std::string& kv_cache_dtype,
                            torch::Tensor& k_scale, torch::Tensor& v_scale);

33
34
35
36
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                             torch::Tensor& key_cache,
                             torch::Tensor& value_cache,
                             torch::Tensor& slot_mapping,
37
                             const std::string& kv_cache_dtype,
38
                             torch::Tensor& k_scale, torch::Tensor& v_scale);
39

40
41
42
43
// Declaration only; the definition is not part of this header.
// Takes two input tensors (`kv_c`, `k_pe`), one cache tensor (`kv_cache`)
// and a single `scale` tensor, all by mutable reference.
// NOTE(review): presumably concatenates `kv_c` and `k_pe` and stores the
// result in `kv_cache` at the positions given by `slot_mapping` -- confirm
// against the implementation.
void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
                          torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
                          const std::string& kv_cache_dtype,
                          torch::Tensor& scale);
44

45
// Just for unittest
46
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
47
                 const double scale, const std::string& kv_cache_dtype);
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

// Declaration only; the definition is not part of this header.
// The cache vectors are const refs, but the Tensors they hold are not const.
// NOTE(review): presumably fills `keys`/`values` from the per-layer caches
// at the positions given by `slot_mapping` -- confirm against the
// implementation.
void read_cache(torch::Tensor& keys, torch::Tensor& values,
                std::vector<torch::Tensor> const& key_caches,
                std::vector<torch::Tensor> const& value_caches,
                torch::Tensor& slot_mapping,
                const std::string& kv_cache_dtype);

// Declaration only; the definition is not part of this header.
// Takes the same parameter list as read_cache. The cache vectors are const
// refs, but the Tensors they hold are not const.
// NOTE(review): presumably stores `keys`/`values` into the per-layer caches
// at the positions given by `slot_mapping` -- confirm against the
// implementation.
void write_cache_multi_layers(torch::Tensor& keys, torch::Tensor& values,
                              std::vector<torch::Tensor> const& key_caches,
                              std::vector<torch::Tensor> const& value_caches,
                              torch::Tensor& slot_mapping,
                              const std::string& kv_cache_dtype);
zhuwenwen's avatar
zhuwenwen committed
64

65
66
67
68
69
void gather_cache(
    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
zhuwenwen's avatar
zhuwenwen committed
70
    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);