// cache.h
#pragma once

#include <torch/all.h>
#include <c10/util/Optional.h>

#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <vector>

9
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
10
                 int64_t block_size_in_bytes,
11
                 const torch::Tensor& block_mapping);
Woosuk Kwon's avatar
Woosuk Kwon committed
12

13
14
15
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
16
17
                       const std::string& kv_cache_dtype,
                       torch::Tensor& k_scale, torch::Tensor& v_scale);
Woosuk Kwon's avatar
Woosuk Kwon committed
18

19
20
21
22
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                             torch::Tensor& key_cache,
                             torch::Tensor& value_cache,
                             torch::Tensor& slot_mapping,
23
                             const std::string& kv_cache_dtype,
24
                             torch::Tensor& k_scale, torch::Tensor& v_scale);
25

26
27
28
29
30
// Stores `kv_c` and `k_pe` together into `kv_cache`.
//
// NOTE(review): inferred from names only — presumably `kv_c` and `k_pe` are
// concatenated per token and written to the slot given by `slot_mapping`,
// quantized with `scale` according to `kv_cache_dtype`. Confirm against the
// definition.
//
//   kv_c, k_pe      Input tensors (note the argument order: kv_c first).
//   kv_cache        Cache tensor to write into.
//   slot_mapping    Destination slot per token (presumably).
//   kv_cache_dtype  String tag for the cache storage dtype.
//   scale           Quantization scale tensor.
void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
                          torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
                          const std::string& kv_cache_dtype,
                          torch::Tensor& scale);

// NOTE: in the disabled declaration below, k_pe comes before kv_c, which is
// the reverse of the argument order used by concat_and_cache_mla.
// void concat_and_cache_mla_rope_fused(
//     torch::Tensor& positions, torch::Tensor& q_pe, torch::Tensor& k_pe,
//     torch::Tensor& kv_c, torch::Tensor& rope_cos_sin_cache, bool rope_is_neox,
//     torch::Tensor& kv_cache_slot_mapping, torch::Tensor& kv_cache,
//     const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale);

// Just for unittest
39
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
40
                 const double scale, const std::string& kv_cache_dtype);
41

42
void gather_and_maybe_dequant_cache(
43
44
45
46
47
48
    torch::Tensor const& src_cache,     // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
    torch::Tensor const& dst,           // [TOT_TOKENS, ENTRIES...]
    torch::Tensor const& block_table,   // [BATCH, BLOCK_INDICES]
    torch::Tensor const& cu_seq_lens,   // [BATCH+1]
    torch::Tensor const& token_to_seq,  // [MAX_TOKEN_ACROSS_CHUNKS]
    int64_t num_tokens, const std::string& kv_cache_dtype,
49
    torch::Tensor const& scale,
50
51
52
53
54
55
56
57
58
    std::optional<torch::Tensor> seq_starts = std::nullopt);

// TODO(hc): cp_gather_cache needs to support a scaled KV cache in the future.
//
// Gathers cache entries into `dst` for `batch_size` sequences.
// NOTE(review): inferred from names and shape comments — presumably each
// sequence's tokens are located via `block_table` and placed in `dst` at the
// offsets given by `cu_seq_lens`, optionally starting from `seq_starts`.
// `dst` is presumably written despite the `const&` on the Tensor handle.
// Confirm against the definition.
void cp_gather_cache(
    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);

// Gathers the FP8 KV cache and upconverts it into a BF16 workspace (`dst`).
//
// NOTE(review): the fixed widths in the shape comments (656 per cache entry,
// 576 per output token) are taken as-is; their breakdown is not visible here.
// `workspace_starts` presumably gives each sequence's starting row in `dst`,
// and `dst` is presumably written despite the `const&` on the Tensor handle.
// Confirm against the definition.
void cp_gather_and_upconvert_fp8_kv_cache(
    torch::Tensor const& src_cache,         // [NUM_BLOCKS, BLOCK_SIZE, 656]
    torch::Tensor const& dst,               // [TOT_TOKENS, 576]
    torch::Tensor const& block_table,       // [BATCH, BLOCK_INDICES]
    torch::Tensor const& seq_lens,          // [BATCH]
    torch::Tensor const& workspace_starts,  // [BATCH]
    int64_t batch_size);

// Indexer K quantization and cache function.
//
// NOTE(review): inferred from names — presumably quantizes `k` in groups of
// `quant_block_size` elements and writes the values (and their scales) into
// `kv_cache` at the slots given by `slot_mapping`. `scale_fmt` is a string
// tag, presumably selecting how scales are encoded. Confirm against the
// definition.
void indexer_k_quant_and_cache(
    torch::Tensor& k,             // [num_tokens, head_dim]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
    torch::Tensor& slot_mapping,  // [num_tokens]
    int64_t quant_block_size,     // quantization block size
    const std::string& scale_fmt);

// Concatenates the query "nope" and rope parts for MLA/DSA attention.
//
// Per the shape comments, `q_out`'s last dimension is nope_dim + rope_dim.
// NOTE(review): `q_out` is presumably the output buffer that receives
// [ql_nope | q_pe] along the last dimension — confirm against the definition.
void concat_mla_q(
    torch::Tensor& ql_nope,  // [num_tokens, num_heads, nope_dim]
    torch::Tensor& q_pe,     // [num_tokens, num_heads, rope_dim]
    torch::Tensor& q_out);   // [num_tokens, num_heads, nope_dim + rope_dim]

// Gathers the quantized K cache out of `kv_cache` into `dst_k` / `dst_scale`.
//
// NOTE(review): presumably reads the layout written by
// indexer_k_quant_and_cache: quantized values go to `dst_k` and their scales
// to `dst_scale`, with tokens located via `block_table` and `cu_seq_lens`.
// Confirm against the definition.
void cp_gather_indexer_k_quant_cache(
    const torch::Tensor& kv_cache,  // [num_blocks, block_size, cache_stride]
    torch::Tensor& dst_k,           // [num_tokens, head_dim]
    torch::Tensor& dst_scale,  // [num_tokens, head_dim / quant_block_size * 4]
    const torch::Tensor& block_table,   // [batch_size, num_blocks]
    const torch::Tensor& cu_seq_lens);  // [batch_size + 1]