# tree_decoding_utils.py
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union, Optional
import torch

from vllm.attention.backends.blocksparse_attn import BlocksparseFlashAttentionImpl
from vllm import _custom_ops as ops
from vllm.attention.ops.paged_attn import PagedAttention

def move_cache(
    backend,
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move cached key/value entries between positions in every layer's KV cache.

    The move happens in two phases through temporary staging tensors.
    ``ops.read_cache`` is given the source indices ``src_to_dists[:, 0]`` and
    the per-layer key/value caches, and fills the staging tensors. Then
    ``ops.write_cache_multi_layers`` is given the destination indices
    ``src_to_dists[:, 1]`` together with the same staging tensors and caches.

    Args:
        backend: Attention backend; only its ``get_name()`` is consulted.
        kv_caches: One KV-cache tensor per layer, in whatever layout
            ``PagedAttention.split_kv_cache`` accepts. The staging buffers take
            the dtype/device of ``kv_caches[0]``, so all layers are assumed to
            share them.
        src_to_dists: Indexed as ``[:, 0]`` (source) and ``[:, 1]``
            (destination), one row per entry to move. Presumably slot indices
            into the paged cache -- TODO confirm against the custom ops.
        kv_cache_dtype: Cache dtype string, forwarded unchanged to both ops.
        num_kv_heads: Number of KV heads; used to split each cache and to size
            the staging buffers.
        head_size: Per-head size; used the same way as ``num_kv_heads``.

    Raises:
        NotImplementedError: If ``backend.get_name()`` is not one of the
            backends handled here.
    """
    supported_backends = ("rocm-flash-attn", "xformers")
    backend_name = backend.get_name()
    if backend_name not in supported_backends:
        raise NotImplementedError(
            "move_cache is only supported for the "
            f"{'/'.join(supported_backends)} backends for now; "
            f"got {backend_name!r}."
        )

    num_layers = len(kv_caches)
    token_num = src_to_dists.shape[0]
    # Nothing to move: skip the allocation and the kernel launches. This also
    # avoids indexing kv_caches[0] when no layers were passed.
    if num_layers == 0 or token_num == 0:
        return

    # Slice both index columns before launching anything, so a malformed
    # src_to_dists fails before any op has run.
    src_slots = src_to_dists[:, 0].contiguous()
    dst_slots = src_to_dists[:, 1].contiguous()

    # Staging buffers holding the gathered keys / values for every layer.
    staging_shape = (num_layers, token_num, num_kv_heads, head_size)
    keys = torch.empty(
        staging_shape, dtype=kv_caches[0].dtype, device=kv_caches[0].device
    )
    values = torch.empty_like(keys)

    split_caches = [
        PagedAttention.split_kv_cache(kv_cache, num_kv_heads, head_size)
        for kv_cache in kv_caches
    ]
    key_caches = [key_cache for key_cache, _ in split_caches]
    value_caches = [value_cache for _, value_cache in split_caches]

    # All sources are read into the staging buffers before any destination is
    # written, presumably so overlapping src/dst sets cannot clobber each other
    # (assumes both ops execute in order on the same stream -- TODO confirm).
    ops.read_cache(
        keys, values, key_caches, value_caches, src_slots, kv_cache_dtype
    )
    ops.write_cache_multi_layers(
        keys, values, key_caches, value_caches, dst_slots, kv_cache_dtype
    )