"""kvcacheio: thin Python wrappers around the ``torch.ops.sgl_kernel`` KV transfer ops."""
from typing import List

import torch


def is_hip() -> bool:
    """Report whether the running PyTorch build targets HIP.

    ``torch.version.hip`` is ``None`` on builds without HIP support, so any
    non-``None`` value means HIP.
    """
    hip_version = torch.version.hip
    return hip_version is not None


# Computed once at import time so function default arguments can branch on it.
_is_hip = is_hip()


def transfer_kv_per_layer(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke ``torch.ops.sgl_kernel.transfer_kv_per_layer`` for one layer's K and V.

    Every argument is forwarded positionally and unchanged; the op's return
    value is discarded, so this returns ``None``. The destination tensors are
    presumably written in place by the kernel — confirm against the op.

    Args:
        src_k, dst_k: Source / destination key tensors.
        src_v, dst_v: Source / destination value tensors.
        src_indices, dst_indices: Index tensors selecting the items to read and
            where to write them (presumably paired element-wise — confirm).
        item_size: Size of one item; units are defined by the kernel
            (TODO confirm bytes vs. elements).
        block_quota: Launch tuning parameter, forwarded as-is.
        num_warps_per_block: Number of warps per block; defaults to 16 on HIP
            builds and 32 otherwise.
    """
    # Positional order must match the op's registered schema.
    torch.ops.sgl_kernel.transfer_kv_per_layer(
        src_k,
        dst_k,
        src_v,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_per_layer_pf_lf(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke ``torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf`` for one layer.

    The ``_pf_lf`` suffix presumably denotes a page-first source layout copied
    into a layer-first destination layout — confirm against the kernel.
    Arguments are forwarded positionally and unchanged; returns ``None``.

    Args:
        src_k, dst_k: Source / destination key tensors.
        src_v, dst_v: Source / destination value tensors.
        src_indices, dst_indices: Index tensors selecting the items to read and
            where to write them.
        layer_id: Layer whose data is transferred (presumably indexes the
            layer axis of the source — TODO confirm).
        item_size: Size of one item; units defined by the kernel.
        src_layout_dim: Layout parameter describing the source buffer,
            forwarded as-is (semantics defined by the kernel — TODO confirm).
        block_quota: Launch tuning parameter, forwarded as-is.
        num_warps_per_block: Number of warps per block; defaults to 16 on HIP
            builds and 32 otherwise.
    """
    # Positional order must match the op's registered schema.
    torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
        src_k,
        dst_k,
        src_v,
        dst_v,
        src_indices,
        dst_indices,
        layer_id,
        item_size,
        src_layout_dim,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer(
    src_k_layers: torch.Tensor,
    dst_k_layers: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke ``torch.ops.sgl_kernel.transfer_kv_all_layer`` across ``num_layers`` layers.

    Arguments are forwarded positionally and unchanged; returns ``None``.

    Args:
        src_k_layers, dst_k_layers: Source / destination key tensors covering
            all layers (exact structure defined by the kernel — TODO confirm).
        src_v_layers, dst_v_layers: Source / destination value tensors, as above.
        src_indices, dst_indices: Index tensors selecting the items to read and
            where to write them.
        item_size: Size of one item; units defined by the kernel.
        num_layers: Number of layers to transfer.
        block_quota: Launch tuning parameter, forwarded as-is.
        num_warps_per_block: Number of warps per block; defaults to 16 on HIP
            builds and 32 otherwise.
    """
    # Positional order must match the op's registered schema.
    torch.ops.sgl_kernel.transfer_kv_all_layer(
        src_k_layers,
        dst_k_layers,
        src_v_layers,
        dst_v_layers,
        src_indices,
        dst_indices,
        item_size,
        num_layers,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer_lf_pf(
    src_k_layers: torch.Tensor,
    dst_k: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke ``torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf`` across all layers.

    The ``_lf_pf`` suffix presumably denotes layer-first sources copied into a
    page-first destination — confirm against the kernel. Arguments are
    forwarded positionally and unchanged; returns ``None``.

    Args:
        src_k_layers, dst_k: Per-layer source keys and the destination key tensor.
        src_v_layers, dst_v: Per-layer source values and the destination value
            tensor.
        src_indices, dst_indices: Index tensors selecting the items to read and
            where to write them.
        item_size: Size of one item; units defined by the kernel.
        dst_layout_dim: Layout parameter describing the destination buffer,
            forwarded as-is (TODO confirm semantics).
        num_layers: Number of layers to transfer.
        block_quota: Launch tuning parameter, forwarded as-is.
        num_warps_per_block: Number of warps per block; defaults to 16 on HIP
            builds and 32 otherwise.
    """
    # Positional order must match the op's registered schema.
    torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
        src_k_layers,
        dst_k,
        src_v_layers,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        dst_layout_dim,
        num_layers,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_direct(
    src_layers: List[torch.Tensor],
    dst_layers: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    page_size: int,
):
    """Invoke ``torch.ops.sgl_kernel.transfer_kv_direct`` on lists of layer tensors.

    The five arguments are passed through positionally, in declaration order;
    the op's result is not returned, so the call evaluates to ``None``.

    Args:
        src_layers: Source tensors, one list entry per layer (presumably).
        dst_layers: Destination tensors, matched to ``src_layers``.
        src_indices: Indices selecting what is read from the sources.
        dst_indices: Indices selecting where it is written in the destinations.
        page_size: Page granularity forwarded to the kernel (TODO confirm units).
    """
    direct_op = torch.ops.sgl_kernel.transfer_kv_direct
    direct_op(src_layers, dst_layers, src_indices, dst_indices, page_size)


def transfer_kv_per_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke ``torch.ops.sgl_kernel.transfer_kv_per_layer_mla`` for one layer.

    Unlike the K/V variants, this takes a single ``src``/``dst`` tensor pair
    (presumably a combined MLA cache buffer — confirm). Arguments are forwarded
    positionally and unchanged; returns ``None``.

    Args:
        src, dst: Source / destination cache tensors.
        src_indices, dst_indices: Index tensors selecting the items to read and
            where to write them.
        item_size: Size of one item; units defined by the kernel.
        block_quota: Launch tuning parameter, forwarded as-is.
        num_warps_per_block: Number of warps per block; defaults to 16 on HIP
            builds and 32 otherwise.
    """
    # Positional order must match the op's registered schema.
    torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
        src,
        dst,
        src_indices,
        dst_indices,
        item_size,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_per_layer_mla_pf_lf(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke ``torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf`` for one layer.

    Single-buffer (MLA) counterpart of ``transfer_kv_per_layer_pf_lf``; the
    ``_pf_lf`` suffix presumably means page-first source to layer-first
    destination — confirm. Arguments are forwarded positionally and unchanged;
    returns ``None``.

    Args:
        src, dst: Source / destination cache tensors.
        src_indices, dst_indices: Index tensors selecting the items to read and
            where to write them.
        layer_id: Layer whose data is transferred (TODO confirm indexing).
        item_size: Size of one item; units defined by the kernel.
        src_layout_dim: Layout parameter describing the source buffer,
            forwarded as-is (TODO confirm semantics).
        block_quota: Launch tuning parameter, forwarded as-is.
        num_warps_per_block: Number of warps per block; defaults to 16 on HIP
            builds and 32 otherwise.
    """
    # Positional order must match the op's registered schema.
    torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
        src,
        dst,
        src_indices,
        dst_indices,
        layer_id,
        item_size,
        src_layout_dim,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer_mla(
    src_layers: torch.Tensor,
    dst_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke ``torch.ops.sgl_kernel.transfer_kv_all_layer_mla`` across ``num_layers`` layers.

    Single-buffer (MLA) counterpart of ``transfer_kv_all_layer``. Arguments
    are forwarded positionally and unchanged; returns ``None``.

    Args:
        src_layers, dst_layers: Source / destination tensors covering all
            layers (exact structure defined by the kernel — TODO confirm).
        src_indices, dst_indices: Index tensors selecting the items to read and
            where to write them.
        item_size: Size of one item; units defined by the kernel.
        num_layers: Number of layers to transfer.
        block_quota: Launch tuning parameter, forwarded as-is.
        num_warps_per_block: Number of warps per block; defaults to 16 on HIP
            builds and 32 otherwise.
    """
    # Positional order must match the op's registered schema.
    torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
        src_layers,
        dst_layers,
        src_indices,
        dst_indices,
        item_size,
        num_layers,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer_mla_lf_pf(
    src_layers: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Invoke ``torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf`` across all layers.

    Single-buffer (MLA) counterpart of ``transfer_kv_all_layer_lf_pf``; the
    ``_lf_pf`` suffix presumably means layer-first sources to a page-first
    destination — confirm. Arguments are forwarded positionally and unchanged;
    returns ``None``.

    Args:
        src_layers: Per-layer source tensors.
        dst: Destination tensor.
        src_indices, dst_indices: Index tensors selecting the items to read and
            where to write them.
        item_size: Size of one item; units defined by the kernel.
        dst_layout_dim: Layout parameter describing the destination buffer,
            forwarded as-is (TODO confirm semantics).
        num_layers: Number of layers to transfer.
        block_quota: Launch tuning parameter, forwarded as-is.
        num_warps_per_block: Number of warps per block; defaults to 16 on HIP
            builds and 32 otherwise.
    """
    # Positional order must match the op's registered schema.
    torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
        src_layers,
        dst,
        src_indices,
        dst_indices,
        item_size,
        dst_layout_dim,
        num_layers,
        block_quota,
        num_warps_per_block,
    )