"docs/vscode:/vscode.git/clone" did not exist on "dd82ab14972a17cdf4570bee20ad15c10167bd6f"
kvcacheio.py 5.7 KB
Newer Older
1
2
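"""kvcacheio.py: thin Python wrappers around the ``torch.ops.sgl_kernel`` KV-cache transfer ops.

Functions taking an ``io_backend`` argument dispatch between the fused
"kernel" path and the plain "direct" copy path. The ``*_pf_lf`` / ``*_lf_pf``
suffixes presumably denote page-first <-> layer-first layout conversion, and
the ``*_mla`` variants operate on a single combined KV tensor per layer.
"""
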
from typing import List

import torch


def transfer_kv_per_layer(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
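    """Copy the indexed KV pages of one layer from ``src_k``/``src_v`` to ``dst_k``/``dst_v``.

    Dispatches to the fused ``sgl_kernel`` transfer op when ``io_backend`` is
    "kernel", or to a plain page-wise copy when it is "direct".
    """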
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_per_layer(
            src_k,
            dst_k,
            src_v,
            dst_v,
            src_indices,
            dst_indices,
            item_size * src_k.element_size(),  # TODO: hot fix for compatibility
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
        torch.ops.sgl_kernel.transfer_kv_direct(
            [src_k, src_v], [dst_k, dst_v], src_indices, dst_indices, page_size
        )
    else:
        raise ValueError(f"Unsupported io backend: {io_backend}")


def transfer_kv_per_layer_pf_lf(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
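    """Single-layer transfer across layouts (kernel backend only).

    The ``pf_lf`` suffix presumably means copying from a page-first source
    into a layer-first destination; ``src_layout_dim`` is assumed to describe
    the source layout stride.
    """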
    torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
        src_k,
        dst_k,
        src_v,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        src_layout_dim,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer(
    src_k_layers: torch.Tensor,
    dst_k_layers: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
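    """Copy the indexed KV pages of all layers in one call.

    Only the "kernel" backend is implemented; the "direct" path for this
    interface is deprecated and raises ``NotImplementedError``.
    """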
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_all_layer(
            src_k_layers,
            dst_k_layers,
            src_v_layers,
            dst_v_layers,
            src_indices,
            dst_indices,
            item_size,
            num_layers,
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
        raise NotImplementedError("Deprecated interface")
    else:
        raise ValueError(f"Unsupported io backend: {io_backend}")


def transfer_kv_all_layer_lf_pf(
    src_k_layers: torch.Tensor,
    dst_k: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
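    """All-layer transfer from a layer-first source into a page-first
    destination (presumed meaning of ``lf_pf``); kernel backend only.
    """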
    torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
        src_k_layers,
        dst_k,
        src_v_layers,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        dst_layout_dim,
        num_layers,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_direct(
    src_layers: List[torch.Tensor],
    dst_layers: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    page_size: int,
):
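    """Direct (non-fused) copy of the indexed pages from each tensor in
    ``src_layers`` to the corresponding tensor in ``dst_layers``.
    """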
    torch.ops.sgl_kernel.transfer_kv_direct(
        src_layers, dst_layers, src_indices, dst_indices, page_size
    )


def transfer_kv_per_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    page_size: int,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
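    """Single-layer transfer for MLA-style caches, where K and V share one
    combined tensor per layer; dispatches between the "kernel" and "direct"
    backends like ``transfer_kv_per_layer``.
    """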
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
            src,
            dst,
            src_indices,
            dst_indices,
            item_size * src.element_size(),  # TODO: hot fix for compatibility
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
        torch.ops.sgl_kernel.transfer_kv_direct(
            [src], [dst], src_indices, dst_indices, page_size
        )
    else:
        raise ValueError(f"Unsupported io backend: {io_backend}")


def transfer_kv_per_layer_mla_pf_lf(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
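    """MLA variant of the per-layer page-first -> layer-first transfer
    (kernel backend only).
    """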
    torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
        src,
        dst,
        src_indices,
        dst_indices,
        item_size,
        src_layout_dim,
        block_quota,
        num_warps_per_block,
    )


def transfer_kv_all_layer_mla(
    src_layers: torch.Tensor,
    dst_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    io_backend: str,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
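    """All-layer transfer for MLA-style combined KV tensors; only the
    "kernel" backend is implemented, the "direct" path is deprecated.
    """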
    if io_backend == "kernel":
        torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
            src_layers,
            dst_layers,
            src_indices,
            dst_indices,
            item_size,
            num_layers,
            block_quota,
            num_warps_per_block,
        )
    elif io_backend == "direct":
        raise NotImplementedError("Deprecated interface")
    else:
        raise ValueError(f"Unsupported io backend: {io_backend}")


def transfer_kv_all_layer_mla_lf_pf(
    src_layers: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 32,
):
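    """MLA variant of the all-layer layer-first -> page-first transfer
    (kernel backend only).
    """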
    torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
        src_layers,
        dst,
        src_indices,
        dst_indices,
        item_size,
        dst_layout_dim,
        num_layers,
        block_quota,
        num_warps_per_block,
    )