#pragma once

#include <torch/extension.h>

#include <map>
#include <string>
#include <vector>

void swap_blocks(
Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
  torch::Tensor& src,
  torch::Tensor& dst,
11
  const torch::Tensor& block_mapping);

void copy_blocks(
14
15
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
16
  const torch::Tensor& block_mapping);

void reshape_and_cache(
  torch::Tensor& key,
  torch::Tensor& value,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
23
  torch::Tensor& slot_mapping,
24
25
  const std::string& kv_cache_dtype,
  const float kv_scale);

// Variant of reshape_and_cache with the same key/value/cache/slot arguments
// but no `kv_scale` parameter.
// NOTE(review): the "flash" suffix suggests a cache layout used by a
// flash-attention backend — confirm the expected tensor layout against the
// implementation.
//
// @param key, value                Tensors to write.
// @param key_cache, value_cache    Destination caches.
// @param slot_mapping              Tensor selecting the destination slot(s).
// @param kv_cache_dtype            Name of the cache storage dtype.
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                             torch::Tensor& key_cache,
                             torch::Tensor& value_cache,
                             torch::Tensor& slot_mapping,
                             const std::string& kv_cache_dtype);

// Exposed only for unit tests.
void convert_fp8(
37
38
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache);