zhaoyu6 / sglang · Commit b4326330 (Unverified)
Authored Jul 23, 2025 by Zhiqiang Xie; committed by GitHub on Jul 23, 2025
Hicache IO kernel refactoring (#8264)
Parent: 8abd3e77

Showing 5 changed files with 524 additions and 259 deletions (+524, −259)
sgl-kernel/csrc/common_extension.cc        +21   −16
sgl-kernel/csrc/kvcacheio/transfer.cu      +269  −146
sgl-kernel/include/sgl_kernel_ops.h        +38   −23
sgl-kernel/python/sgl_kernel/kvcacheio.py  +130  −30
sgl-kernel/tests/test_kvcacheio.py         +66   −44
sgl-kernel/csrc/common_extension.cc

@@ -249,34 +249,39 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()"
);
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_per_layer"
,
torch
::
kCUDA
,
&
transfer_kv_per_layer
);
m
.
impl
(
"transfer_kv_per_layer"
,
torch
::
kCUDA
,
&
transfer_kv_per_layer
);
m
.
def
(
m
.
def
(
"transfer_kv_per_layer_
direct
(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"transfer_kv_per_layer_
pf_lf
(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int
page
_size) -> ()"
);
"dst_indices, int
item
_size
, int src_layout_dim, int block_quota, int num_warps_per_block
) -> ()"
);
m
.
impl
(
"transfer_kv_per_layer_
direct
"
,
torch
::
kCUDA
,
&
transfer_kv_per_layer_
direct
);
m
.
impl
(
"transfer_kv_per_layer_
pf_lf
"
,
torch
::
kCUDA
,
&
transfer_kv_per_layer_
pf_lf
);
m
.
def
(
m
.
def
(
"transfer_kv_all_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v
, Tensor src_indices, Tensor
"
"transfer_kv_all_layer(Tensor src_k
_layers
, Tensor dst_k
_layers
, Tensor src_v
_layers
, Tensor dst_v
_layers,
"
"dst_indices, int item_size, int num_layers, int
src_layer_offset, int dst_layer_offset, int
block_quota, int "
"
Tensor src_indices, Tensor
dst_indices, int item_size, int num_layers, int block_quota, int "
"num_warps_per_block) -> ()"
);
"num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_all_layer"
,
torch
::
kCUDA
,
&
transfer_kv_all_layer
);
m
.
impl
(
"transfer_kv_all_layer"
,
torch
::
kCUDA
,
&
transfer_kv_all_layer
);
m
.
def
(
m
.
def
(
"transfer_kv_all_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"transfer_kv_all_layer_lf_pf(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, "
"dst_indices, int page_size, int num_layers) -> ()"
);
"Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int "
m
.
impl
(
"transfer_kv_all_layer_direct"
,
torch
::
kCUDA
,
&
transfer_kv_all_layer_direct
);
"num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_all_layer_lf_pf"
,
torch
::
kCUDA
,
&
transfer_kv_all_layer_lf_pf
);
m
.
def
(
m
.
def
(
"transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
"transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
"block_quota, int num_warps_per_block) -> ()"
);
"block_quota, int num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_per_layer_mla"
,
torch
::
kCUDA
,
&
transfer_kv_per_layer_mla
);
m
.
impl
(
"transfer_kv_per_layer_mla"
,
torch
::
kCUDA
,
&
transfer_kv_per_layer_mla
);
m
.
def
(
m
.
def
(
"transfer_kv_per_layer_mla_
direct
(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int
page
_size
)
"
"transfer_kv_per_layer_mla_
pf_lf
(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int
item
_size
,
"
"-> ()"
);
"
int src_layout_dim, int block_quota, int num_warps_per_block)
-> ()"
);
m
.
impl
(
"transfer_kv_per_layer_mla_
direct
"
,
torch
::
kCUDA
,
&
transfer_kv_per_layer_mla_
direct
);
m
.
impl
(
"transfer_kv_per_layer_mla_
pf_lf
"
,
torch
::
kCUDA
,
&
transfer_kv_per_layer_mla_
pf_lf
);
m
.
def
(
m
.
def
(
"transfer_kv_all_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int
item_size, int
"
"transfer_kv_all_layer_mla(Tensor src
_layers
, Tensor dst
_layers
, Tensor src_indices, Tensor dst_indices, int "
"
num_layers, int src_layer_offset
, int
dst
_layer
_offset
, int block_quota, int num_warps_per_block) -> ()"
);
"
item_size
, int
num
_layer
s
, int block_quota, int num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_all_layer_mla"
,
torch
::
kCUDA
,
&
transfer_kv_all_layer_mla
);
m
.
impl
(
"transfer_kv_all_layer_mla"
,
torch
::
kCUDA
,
&
transfer_kv_all_layer_mla
);
m
.
def
(
m
.
def
(
"transfer_kv_all_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size, "
"transfer_kv_all_layer_mla_lf_pf(Tensor src_layers, Tensor dst, Tensor src_indices, Tensor dst_indices, "
"int num_layers) -> ()"
);
"int item_size, int dst_layout_dim, int num_layers, int block_quota, int num_warps_per_block) -> ()"
);
m
.
impl
(
"transfer_kv_all_layer_mla_direct"
,
torch
::
kCUDA
,
&
transfer_kv_all_layer_mla_direct
);
m
.
impl
(
"transfer_kv_all_layer_mla_lf_pf"
,
torch
::
kCUDA
,
&
transfer_kv_all_layer_mla_lf_pf
);
m
.
def
(
"transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
"page_size) -> ()"
);
m
.
impl
(
"transfer_kv_direct"
,
torch
::
kCUDA
,
&
transfer_kv_direct
);
/*
/*
* From csrc/moe/cutlass_moe/w4a8
* From csrc/moe/cutlass_moe/w4a8
...
...
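Each m.def above registers an operator schema and each m.impl binds its CUDA implementation, so the renamed kernels surface in Python under torch.ops.sgl_kernel. A minimal sketch of a call site for the new consolidated direct op, assuming the sgl_kernel extension is importable; shapes and values are illustrative, not taken from this diff:

import torch
import sgl_kernel  # importing the package registers the sgl_kernel ops

src_indices = torch.arange(8, dtype=torch.int64)
dst_indices = torch.arange(8, dtype=torch.int64)
src_layers = [torch.randn(16, 128, device="cuda") for _ in range(2)]  # one tensor per layer
dst_layers = [torch.empty_like(t) for t in src_layers]

# page_size=1 here: each index addresses a single item
torch.ops.sgl_kernel.transfer_kv_direct(src_layers, dst_layers, src_indices, dst_indices, 1)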
sgl-kernel/csrc/kvcacheio/transfer.cu

@@ -22,17 +22,40 @@ transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_
   }
 }

-__device__ __forceinline__ int64_t
-get_global_offset_lf(int64_t layer_id, int64_t layer_dim, int64_t page_id, int64_t item_size_bytes) {
-  // layer first
-  return layer_id * layer_dim + page_id * item_size_bytes;
-}
-
-__device__ __forceinline__ int64_t
-get_global_offset_pf(int64_t layer_id, int64_t page_dim, int64_t page_id, int64_t item_size_bytes) {
-  // page first
-  return page_id * page_dim + layer_id * item_size_bytes;
-}
+// todo, structs for different memory layout
+template <typename T>
+__device__ __forceinline__ T* get_global_offset_lf(
+    T* base,
+    const uintptr_t* __restrict__ /*unused*/,
+    int64_t layer_id,
+    int64_t layer_dim,
+    int64_t page_id,
+    int64_t item_size_bytes) {
+  // layer first
+  return base + layer_id * layer_dim + page_id * item_size_bytes;
+}
+
+template <typename T>
+__device__ __forceinline__ T* get_global_offset_pf(
+    T* base,
+    const uintptr_t* __restrict__ /*unused*/,
+    int64_t layer_id,
+    int64_t page_dim,
+    int64_t page_id,
+    int64_t item_size_bytes) {
+  // page first
+  return base + page_id * page_dim + layer_id * item_size_bytes;
+}
+
+// get offset from layer base table when layers are not contiguous
+template <typename T>
+__device__ __forceinline__ T* get_global_offset_lf_tbl(
+    T* /*unused*/,
+    const uintptr_t* __restrict__ layer_base_tbl,
+    int64_t layer_id,
+    int64_t /*unused*/,
+    int64_t page_id,
+    int64_t item_size_bytes) {
+  return reinterpret_cast<T*>(layer_base_tbl[layer_id]) + page_id * item_size_bytes;
+}

 template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
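These offset helpers are the heart of the refactor: every transfer variant is the same kernel instantiated with a different (source, destination) pair of address functions, where lf means "layer first", pf means "page first", and lf_tbl handles non-contiguous layers via a base-pointer table. A small Python model of the same arithmetic, with illustrative names (byte offsets into flat buffers, mirroring the CUDA helpers above):

def addr_lf(base, layer_id, layer_dim, page_id, item_size_bytes):
    # layer first: all pages of a layer are contiguous, layer stride = layer_dim
    return base + layer_id * layer_dim + page_id * item_size_bytes

def addr_pf(base, layer_id, page_dim, page_id, item_size_bytes):
    # page first: all layers of a page are contiguous, page stride = page_dim
    return base + page_id * page_dim + layer_id * item_size_bytes

def addr_lf_tbl(layer_base_tbl, layer_id, page_id, item_size_bytes):
    # non-contiguous layers: fetch the layer's base address from a table
    return layer_base_tbl[layer_id] + page_id * item_size_bytes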
@@ -49,42 +72,37 @@ __global__ void transfer_kernel_impl(
     int64_t items_per_warp,
     int64_t item_size_bytes,
     int64_t src_layout_dim,
-    int64_t dst_layout_dim) {
+    int64_t dst_layout_dim,
+    const uintptr_t* __restrict__ src_k_layer_tbl,
+    const uintptr_t* __restrict__ dst_k_layer_tbl,
+    const uintptr_t* __restrict__ src_v_layer_tbl,
+    const uintptr_t* __restrict__ dst_v_layer_tbl) {
   int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   int32_t lane_id = tid % 32;
   int32_t warp_id = tid / 32;
   for (int i = 0; i < items_per_warp; ++i) {
-    int32_t item_id = warp_id * items_per_warp + i;
+    int64_t item_id = warp_id * items_per_warp + i;
     if (item_id >= num_items) {
-      return;
+      break;
     }
     const int64_t src_page_id = src_indices[item_id];
     const int64_t dst_page_id = dst_indices[item_id];
     // Loop over layers if necessary
     for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) {
-      const int64_t src_offset = SrcOffsetFn(layer_id, src_layout_dim, src_page_id, item_size_bytes);
-      const int64_t dst_offset = DstOffsetFn(layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
-      if constexpr (IsMLA) {
-        transfer_item_warp(
-            lane_id,
-            static_cast<const char*>(src_k) + src_offset,
-            static_cast<char*>(dst_k) + dst_offset,
-            item_size_bytes);
-      } else {
-        transfer_item_warp(
-            lane_id,
-            static_cast<const char*>(src_k) + src_offset,
-            static_cast<char*>(dst_k) + dst_offset,
-            item_size_bytes);
-        transfer_item_warp(
-            lane_id,
-            static_cast<const char*>(src_v) + src_offset,
-            static_cast<char*>(dst_v) + dst_offset,
-            item_size_bytes);
-      }
+      // Calculate offsets using the provided function pointers
+      const char* src_ptr = SrcOffsetFn(
+          static_cast<const char*>(src_k), src_k_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes);
+      char* dst_ptr = DstOffsetFn(
+          static_cast<char*>(dst_k), dst_k_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
+      transfer_item_warp(lane_id, src_ptr, dst_ptr, item_size_bytes);
+      if constexpr (!IsMLA) {
+        const char* src_v_ptr = SrcOffsetFn(
+            static_cast<const char*>(src_v), src_v_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes);
+        char* dst_v_ptr = DstOffsetFn(
+            static_cast<char*>(dst_v), dst_v_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
+        transfer_item_warp(lane_id, src_v_ptr, dst_v_ptr, item_size_bytes);
+      }
     }
   }
 }
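Work decomposition in transfer_kernel_impl: each 32-thread warp copies whole items, a warp owns items_per_warp consecutive item slots, and the lanes cooperate inside transfer_item_warp. A hedged Python model of the index math, with scalar stand-ins for the CUDA built-ins (for reasoning only, not part of the diff):

def items_owned_by_thread(tid, items_per_warp, num_items):
    warp_id = tid // 32  # tid = blockIdx.x * blockDim.x + threadIdx.x
    first = warp_id * items_per_warp
    # the kernel breaks out of its loop once item_id >= num_items
    return [i for i in range(first, first + items_per_warp) if i < num_items]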
@@ -103,44 +121,54 @@ void transfer_kv_launcher(
     int64_t item_size,
     int64_t src_layout_dim,
     int64_t dst_layout_dim,
+    const at::Tensor& src_k_layers,
+    const at::Tensor& dst_k_layers,
+    const at::Tensor& src_v_layers,
+    const at::Tensor& dst_v_layers,
     int64_t block_quota,
     int64_t num_warps_per_block) {
-  TORCH_CHECK(src_k.scalar_type() == dst_k.scalar_type(), "Source and destination keys must have the same type");
   TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
   TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
   TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
   TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long");
   TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
-  if (!IsMLA) {
-    TORCH_CHECK(src_v.scalar_type() == dst_v.scalar_type(), "Source and destination values must have the same type");
-  }
-  int dtype_size = src_k.element_size();
-  TORCH_CHECK((item_size * dtype_size) % 8 == 0, "Item byte size must be divisible by 8");
-  auto div_up = [](int32_t x, int32_t y) { return (x + y - 1) / y; };
+  TORCH_CHECK(item_size % 8 == 0, "Item byte size must be divisible by 8");
+  auto div_up = [](int64_t x, int64_t y) { return (x + y - 1) / y; };
   const int64_t num_items = src_indices.numel();
   const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
   const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
   dim3 grid_dim(num_blocks, 1, 1);
   const int32_t threads_per_block = num_warps_per_block * 32;
+  const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr;
+  void* dst_k_ptr = dst_k.defined() ? dst_k.data_ptr() : nullptr;
+  const void* src_v_ptr = IsMLA || !src_v.defined() ? nullptr : src_v.data_ptr();
+  void* dst_v_ptr = IsMLA || !dst_v.defined() ? nullptr : dst_v.data_ptr();
+  const uintptr_t* src_k_tbl_ptr = src_k_layers.defined() ? src_k_layers.data_ptr<uintptr_t>() : nullptr;
+  const uintptr_t* dst_k_tbl_ptr = dst_k_layers.defined() ? dst_k_layers.data_ptr<uintptr_t>() : nullptr;
+  const uintptr_t* src_v_tbl_ptr = IsMLA || !src_v_layers.defined() ? nullptr : src_v_layers.data_ptr<uintptr_t>();
+  const uintptr_t* dst_v_tbl_ptr = IsMLA || !dst_v_layers.defined() ? nullptr : dst_v_layers.data_ptr<uintptr_t>();
   cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
   transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
-      src_k.data_ptr(),
-      dst_k.data_ptr(),
-      (IsMLA ? nullptr : src_v.data_ptr()),
-      (IsMLA ? nullptr : dst_v.data_ptr()),
+      src_k_ptr,
+      dst_k_ptr,
+      src_v_ptr,
+      dst_v_ptr,
       src_indices.data_ptr<int64_t>(),
       dst_indices.data_ptr<int64_t>(),
       start_layer_id,
       num_layers_to_process,
       num_items,
       items_per_warp,
-      item_size * dtype_size,
-      src_layout_dim * dtype_size,
-      dst_layout_dim * dtype_size);
+      item_size,
+      src_layout_dim,
+      dst_layout_dim,
+      src_k_tbl_ptr,
+      dst_k_tbl_ptr,
+      src_v_tbl_ptr,
+      dst_v_tbl_ptr);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
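Grid sizing in the launcher is two ceiling divisions: items are first spread over at most block_quota * num_warps_per_block warps, and the resulting items_per_warp then fixes how many blocks are actually launched. The same arithmetic in Python (mirrors the div_up lambda above; the worked numbers are an illustrative check, not from the diff):

def launch_shape(num_items, block_quota, num_warps_per_block):
    div_up = lambda x, y: (x + y - 1) // y
    items_per_warp = div_up(num_items, block_quota * num_warps_per_block)
    num_blocks = div_up(num_items, items_per_warp * num_warps_per_block)
    threads_per_block = num_warps_per_block * 32
    return items_per_warp, num_blocks, threads_per_block

# e.g. 1000 items with block_quota=2 and 32 warps/block:
assert launch_shape(1000, 2, 32) == (16, 2, 1024)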
@@ -154,11 +182,28 @@ void transfer_kv_per_layer(
     int64_t item_size,
     int64_t block_quota,
     int64_t num_warps_per_block) {
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
-      src_k, dst_k, src_v, dst_v, src_indices, dst_indices, 0, 1, item_size, 0, 0, block_quota, num_warps_per_block);
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, false>(
+      src_k,
+      dst_k,
+      src_v,
+      dst_v,
+      src_indices,
+      dst_indices,
+      0,
+      1,
+      item_size,
+      0,
+      0,
+      empty,
+      empty,
+      empty,
+      empty,
+      block_quota,
+      num_warps_per_block);
 }

-void transfer_kv_all_layer(
+void transfer_kv_per_layer_pf_lf(
     const at::Tensor src_k,
     at::Tensor dst_k,
     const at::Tensor src_v,
@@ -166,12 +211,11 @@ void transfer_kv_all_layer(
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
     int64_t item_size,
-    int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
+    int64_t src_layout_dim,
     int64_t block_quota,
     int64_t num_warps_per_block) {
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, false>(
       src_k,
       dst_k,
       src_v,
@@ -179,10 +223,81 @@ void transfer_kv_all_layer(
       src_indices,
       dst_indices,
       0,
-      num_layers,
+      1,
       item_size,
-      src_layer_offset,
-      dst_layer_offset,
+      src_layout_dim,
+      0,
+      empty,
+      empty,
+      empty,
+      empty,
       block_quota,
       num_warps_per_block);
 }
+
+void transfer_kv_all_layer(
+    const at::Tensor src_k_layers,
+    const at::Tensor dst_k_layers,
+    const at::Tensor src_v_layers,
+    const at::Tensor dst_v_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, false>(
+      empty,
+      empty,
+      empty,
+      empty,
+      src_indices,
+      dst_indices,
+      0,
+      num_layers,
+      item_size,
+      0,
+      0,
+      src_k_layers,
+      dst_k_layers,
+      src_v_layers,
+      dst_v_layers,
+      block_quota,
+      num_warps_per_block);
+}
+
+void transfer_kv_all_layer_lf_pf(
+    const at::Tensor src_k_layers,
+    at::Tensor dst_k,
+    const at::Tensor src_v_layers,
+    at::Tensor dst_v,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, false>(
+      empty,
+      dst_k,
+      empty,
+      dst_v,
+      src_indices,
+      dst_indices,
+      0,
+      num_layers,
+      item_size,
+      0,
+      dst_layout_dim,
+      src_k_layers,
+      empty,
+      src_v_layers,
+      empty,
+      block_quota,
+      num_warps_per_block);
+}
@@ -195,12 +310,12 @@ void transfer_kv_per_layer_mla(
     int64_t item_size,
     int64_t block_quota,
     int64_t num_warps_per_block) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, true>(
       src,
       dst,
-      empty_tensor,
-      empty_tensor,
+      empty,
+      empty,
       src_indices,
       dst_indices,
       0,
@@ -208,41 +323,110 @@ void transfer_kv_per_layer_mla(
       item_size,
       0,
       0,
+      empty,
+      empty,
+      empty,
+      empty,
       block_quota,
       num_warps_per_block);
 }

-void transfer_kv_all_layer_mla(
+void transfer_kv_per_layer_mla_pf_lf(
     const at::Tensor src,
     at::Tensor dst,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
     int64_t item_size,
-    int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
+    int64_t src_layout_dim,
     int64_t block_quota,
     int64_t num_warps_per_block) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, true>(
       src,
       dst,
-      empty_tensor,
-      empty_tensor,
+      empty,
+      empty,
       src_indices,
       dst_indices,
       0,
-      num_layers,
+      1,
       item_size,
-      src_layer_offset,
-      dst_layer_offset,
+      src_layout_dim,
+      0,
+      empty,
+      empty,
+      empty,
+      empty,
       block_quota,
       num_warps_per_block);
 }
+
+void transfer_kv_all_layer_mla(
+    const at::Tensor src_layers,
+    const at::Tensor dst_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, true>(
+      empty,
+      empty,
+      empty,
+      empty,
+      src_indices,
+      dst_indices,
+      0,
+      num_layers,
+      item_size,
+      0,
+      0,
+      src_layers,
+      dst_layers,
+      empty,
+      empty,
+      block_quota,
+      num_warps_per_block);
+}
+
+void transfer_kv_all_layer_mla_lf_pf(
+    const at::Tensor src_layers,
+    at::Tensor dst,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, true>(
+      empty,
+      dst,
+      empty,
+      empty,
+      src_indices,
+      dst_indices,
+      0,
+      num_layers,
+      item_size,
+      0,
+      dst_layout_dim,
+      src_layers,
+      empty,
+      empty,
+      empty,
+      block_quota,
+      num_warps_per_block);
+}

 inline void transfer_page_direct(
-    const at::Tensor src_buffer,
-    at::Tensor dst_buffer,
+    const at::Tensor& src_buffer,
+    at::Tensor& dst_buffer,
     int64_t src_page_index,
     int64_t dst_page_index,
     int64_t page_size) {
@@ -252,16 +436,14 @@ inline void transfer_page_direct(
   /* non_blocking= */ true);
 }

-template <bool IsMLA, bool AllLayers>
-inline void transfer_kv_direct_impl(
-    const at::Tensor& src_k,
-    at::Tensor& dst_k,
-    const at::Tensor& src_v_opt,  // Only used when IsMLA is false (for src_v)
-    at::Tensor& dst_v_opt,        // Only used when IsMLA is false (for dst_v)
-    const at::Tensor& src_indices,
-    const at::Tensor& dst_indices,
-    int64_t page_size,
-    int64_t num_layers = 1) {
+void transfer_kv_direct(
+    const std::vector<at::Tensor>& src_layers,
+    std::vector<at::Tensor> dst_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t page_size) {
+  TORCH_CHECK(
+      src_layers.size() == dst_layers.size(), "Source and destination layers must have the same number of layers");
   TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
   TORCH_CHECK(page_size > 0, "Page size must be positive");
   TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");
@@ -270,73 +452,14 @@ inline void transfer_kv_direct_impl(
   auto dst_indices_cpu = dst_indices.cpu();
   const int64_t num_pages = src_indices_cpu.size(0) / page_size;
+  const int64_t num_layers = src_layers.size();
-  for (int64_t i = 0; i < num_pages; ++i) {
-    auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
-    auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
-    if constexpr (AllLayers) {
-      for (int64_t j = 0; j < num_layers; ++j) {
-        if constexpr (IsMLA) {
-          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
-        } else {
-          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
-          transfer_page_direct(src_v_opt.select(0, j), dst_v_opt.select(0, j), s_index, d_index, page_size);
-        }
-      }
-    } else {
-      // Per-layer
-      if constexpr (IsMLA) {
-        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
-      } else {
-        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
-        transfer_page_direct(src_v_opt, dst_v_opt, s_index, d_index, page_size);
-      }
-    }
-  }
-}
-
-void transfer_kv_per_layer_direct(
-    const at::Tensor src_k,
-    at::Tensor dst_k,
-    const at::Tensor src_v,
-    at::Tensor dst_v,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size) {
-  transfer_kv_direct_impl<false, false>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size);
-}
-
-void transfer_kv_all_layer_direct(
-    const at::Tensor src_k,
-    at::Tensor dst_k,
-    const at::Tensor src_v,
-    at::Tensor dst_v,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers) {
-  transfer_kv_direct_impl<false, true>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers);
-}
-
-void transfer_kv_per_layer_mla_direct(
-    const at::Tensor src,
-    at::Tensor dst,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_direct_impl<true, false>(src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size);
-}
-
-void transfer_kv_all_layer_mla_direct(
-    const at::Tensor src,
-    at::Tensor dst,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_direct_impl<true, true>(src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size, num_layers);
-}
+  for (const auto i : c10::irange(num_pages)) {
+    auto src_index = src_indices_cpu[i * page_size].item<int64_t>();
+    auto dst_index = dst_indices_cpu[i * page_size].item<int64_t>();
+    for (const auto j : c10::irange(num_layers)) {
+      transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, page_size);
+    }
+  }
+}
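With the hunk above, the four *_direct entry points collapse into a single transfer_kv_direct that takes matching lists of per-layer tensors and loops over pages on the host. A hedged pure-PyTorch sketch of the equivalent behaviour, useful for reasoning about the op (illustrative only; the real op additionally validates list lengths, page_size, and index divisibility):

import torch

def transfer_kv_direct_reference(src_layers, dst_layers, src_indices, dst_indices, page_size):
    num_pages = src_indices.numel() // page_size
    for i in range(num_pages):
        s = int(src_indices[i * page_size])  # first index of the source page
        d = int(dst_indices[i * page_size])  # first index of the destination page
        for src, dst in zip(src_layers, dst_layers):
            dst[d : d + page_size].copy_(src[s : s + page_size], non_blocking=True)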
sgl-kernel/include/sgl_kernel_ops.h

@@ -399,38 +399,42 @@ void transfer_kv_per_layer(
     int64_t block_quota,
     int64_t num_warps_per_block);
-void transfer_kv_per_layer_direct(
+void transfer_kv_per_layer_pf_lf(
     const at::Tensor src_k,
     at::Tensor dst_k,
     const at::Tensor src_v,
     at::Tensor dst_v,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
-    int64_t page_size);
+    int64_t item_size,
+    int64_t src_layout_dim,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
 void transfer_kv_all_layer(
-    const at::Tensor src_k,
-    at::Tensor dst_k,
-    const at::Tensor src_v,
-    at::Tensor dst_v,
+    const at::Tensor src_k_layers,
+    const at::Tensor dst_k_layers,
+    const at::Tensor src_v_layers,
+    const at::Tensor dst_v_layers,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
     int64_t item_size,
     int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
     int64_t block_quota,
     int64_t num_warps_per_block);
-void transfer_kv_all_layer_direct(
-    const at::Tensor src_k,
+void transfer_kv_all_layer_lf_pf(
+    const at::Tensor src_k_layers,
     at::Tensor dst_k,
-    const at::Tensor src_v,
+    const at::Tensor src_v_layers,
     at::Tensor dst_v,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers);
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
 void transfer_kv_per_layer_mla(
     const at::Tensor src,
@@ -441,32 +445,43 @@ void transfer_kv_per_layer_mla(
     int64_t block_quota,
     int64_t num_warps_per_block);
-void transfer_kv_per_layer_mla_direct(
+void transfer_kv_per_layer_mla_pf_lf(
     const at::Tensor src,
     at::Tensor dst,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
-    int64_t page_size);
+    int64_t item_size,
+    int64_t src_layout_dim,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
 void transfer_kv_all_layer_mla(
-    const at::Tensor src,
-    at::Tensor dst,
+    const at::Tensor src_layers,
+    const at::Tensor dst_layers,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
     int64_t item_size,
     int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
    int64_t block_quota,
     int64_t num_warps_per_block);
-void transfer_kv_all_layer_mla_direct(
-    const at::Tensor src,
+void transfer_kv_all_layer_mla_lf_pf(
+    const at::Tensor src_layers,
     at::Tensor dst,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers);
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
+void transfer_kv_direct(
+    const std::vector<at::Tensor>& src_layers,
+    std::vector<at::Tensor> dst_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t page_size);

 /*
  * From csrc/moe/cutlass_moe/w4a8
sgl-kernel/python/sgl_kernel/kvcacheio.py
+from typing import List
+
 import torch
@@ -22,57 +24,116 @@ def transfer_kv_per_layer(
             dst_v,
             src_indices,
             dst_indices,
-            item_size,
+            item_size * src_k.element_size(),  # todo, hot fix for compatibility
             block_quota,
             num_warps_per_block,
         )
     elif io_backend == "direct":
-        torch.ops.sgl_kernel.transfer_kv_per_layer_direct(
-            src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size
+        torch.ops.sgl_kernel.transfer_kv_direct(
+            [src_k, src_v], [dst_k, dst_v], src_indices, dst_indices, page_size
         )
     else:
         raise ValueError(f"Unsupported io backend")


+def transfer_kv_per_layer_pf_lf(
+    src_k: torch.Tensor,
+    dst_k: torch.Tensor,
+    src_v: torch.Tensor,
+    dst_v: torch.Tensor,
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    item_size: int,
+    src_layout_dim: int,
+    block_quota: int = 2,
+    num_warps_per_block: int = 32,
+):
+    torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
+        src_k,
+        dst_k,
+        src_v,
+        dst_v,
+        src_indices,
+        dst_indices,
+        item_size,
+        src_layout_dim,
+        block_quota,
+        num_warps_per_block,
+    )
+
+
 def transfer_kv_all_layer(
-    src_k: torch.Tensor,
-    dst_k: torch.Tensor,
-    src_v: torch.Tensor,
-    dst_v: torch.Tensor,
+    src_k_layers: torch.Tensor,
+    dst_k_layers: torch.Tensor,
+    src_v_layers: torch.Tensor,
+    dst_v_layers: torch.Tensor,
     src_indices: torch.Tensor,
     dst_indices: torch.Tensor,
     io_backend: str,
-    page_size: int,
     item_size: int,
     num_layers: int,
-    src_layer_offset: int,
-    dst_layer_offset: int,
     block_quota: int = 2,
     num_warps_per_block: int = 32,
 ):
     if io_backend == "kernel":
         torch.ops.sgl_kernel.transfer_kv_all_layer(
-            src_k,
-            dst_k,
-            src_v,
-            dst_v,
+            src_k_layers,
+            dst_k_layers,
+            src_v_layers,
+            dst_v_layers,
             src_indices,
             dst_indices,
             item_size,
             num_layers,
-            src_layer_offset,
-            dst_layer_offset,
             block_quota,
             num_warps_per_block,
         )
     elif io_backend == "direct":
-        torch.ops.sgl_kernel.transfer_kv_all_layer_direct(
-            src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers
-        )
+        raise NotImplementedError("Deprecated interface")
     else:
         raise ValueError(f"Unsupported io backend")


+def transfer_kv_all_layer_lf_pf(
+    src_k_layers: torch.Tensor,
+    dst_k: torch.Tensor,
+    src_v_layers: torch.Tensor,
+    dst_v: torch.Tensor,
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    item_size: int,
+    dst_layout_dim: int,
+    num_layers: int,
+    block_quota: int = 2,
+    num_warps_per_block: int = 32,
+):
+    torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
+        src_k_layers,
+        dst_k,
+        src_v_layers,
+        dst_v,
+        src_indices,
+        dst_indices,
+        item_size,
+        dst_layout_dim,
+        num_layers,
+        block_quota,
+        num_warps_per_block,
+    )
+
+
+def transfer_kv_direct(
+    src_layers: List[torch.Tensor],
+    dst_layers: List[torch.Tensor],
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    page_size: int,
+):
+    torch.ops.sgl_kernel.transfer_kv_direct(
+        src_layers, dst_layers, src_indices, dst_indices, page_size
+    )
+
+
 def transfer_kv_per_layer_mla(
     src: torch.Tensor,
     dst: torch.Tensor,
@@ -90,48 +151,87 @@ def transfer_kv_per_layer_mla(
             dst,
             src_indices,
             dst_indices,
-            item_size,
+            item_size * src.element_size(),  # todo, hot fix for compatibility
             block_quota,
             num_warps_per_block,
         )
     elif io_backend == "direct":
-        torch.ops.sgl_kernel.transfer_kv_per_layer_mla_direct(
-            src, dst, src_indices, dst_indices, page_size
+        torch.ops.sgl_kernel.transfer_kv_direct(
+            [src], [dst], src_indices, dst_indices, page_size
         )
     else:
         raise ValueError(f"Unsupported io backend")


+def transfer_kv_per_layer_mla_pf_lf(
+    src: torch.Tensor,
+    dst: torch.Tensor,
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    item_size: int,
+    src_layout_dim: int,
+    block_quota: int = 2,
+    num_warps_per_block: int = 32,
+):
+    torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
+        src,
+        dst,
+        src_indices,
+        dst_indices,
+        item_size,
+        src_layout_dim,
+        block_quota,
+        num_warps_per_block,
+    )
+
+
 def transfer_kv_all_layer_mla(
-    src: torch.Tensor,
-    dst: torch.Tensor,
+    src_layers: torch.Tensor,
+    dst_layers: torch.Tensor,
     src_indices: torch.Tensor,
     dst_indices: torch.Tensor,
     io_backend: str,
-    page_size: int,
     item_size: int,
     num_layers: int,
-    src_layer_offset: int,
-    dst_layer_offset: int,
     block_quota: int = 2,
     num_warps_per_block: int = 32,
 ):
     if io_backend == "kernel":
         torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
-            src,
-            dst,
+            src_layers,
+            dst_layers,
             src_indices,
             dst_indices,
             item_size,
             num_layers,
-            src_layer_offset,
-            dst_layer_offset,
             block_quota,
             num_warps_per_block,
         )
     elif io_backend == "direct":
-        torch.ops.sgl_kernel.transfer_kv_all_layer_mla_direct(
-            src, dst, src_indices, dst_indices, page_size, num_layers
-        )
+        raise NotImplementedError("Deprecated interface")
     else:
         raise ValueError(f"Unsupported io backend")
+
+
+def transfer_kv_all_layer_mla_lf_pf(
+    src_layers: torch.Tensor,
+    dst: torch.Tensor,
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    item_size: int,
+    dst_layout_dim: int,
+    num_layers: int,
+    block_quota: int = 2,
+    num_warps_per_block: int = 32,
+):
+    torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
+        src_layers,
+        dst,
+        src_indices,
+        dst_indices,
+        item_size,
+        dst_layout_dim,
+        num_layers,
+        block_quota,
+        num_warps_per_block,
+    )
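Usage of the reworked all-layer API changes accordingly: callers now hand over a uint64 tensor of per-layer base pointers instead of one contiguous pool, and for the "kernel" backend item_size is expressed in bytes (hence the element_size() multiplications above). A sketch with hypothetical pools, following the pattern the updated test uses:

import torch
from sgl_kernel.kvcacheio import transfer_kv_all_layer

num_layers, num_items, item_size = 4, 256, 128
src_k = [torch.randn(num_items, item_size, device="cuda") for _ in range(num_layers)]
src_v = [torch.randn(num_items, item_size, device="cuda") for _ in range(num_layers)]
dst_k = [torch.empty_like(t) for t in src_k]
dst_v = [torch.empty_like(t) for t in src_v]

def ptr_tbl(pool):
    # per-layer base-pointer table, as built in the updated test
    return torch.tensor([t.data_ptr() for t in pool], dtype=torch.uint64, device="cuda")

indices = torch.arange(num_items, dtype=torch.int64, device="cuda")

transfer_kv_all_layer(
    ptr_tbl(src_k), ptr_tbl(dst_k), ptr_tbl(src_v), ptr_tbl(dst_v),
    indices, indices,
    io_backend="kernel",
    item_size=item_size * src_k[0].element_size(),  # bytes, must be divisible by 8
    num_layers=num_layers,
)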
sgl-kernel/tests/test_kvcacheio.py

@@ -3,6 +3,7 @@ import torch
 from sgl_kernel.kvcacheio import (
     transfer_kv_all_layer,
     transfer_kv_all_layer_mla,
+    transfer_kv_direct,
     transfer_kv_per_layer,
     transfer_kv_per_layer_mla,
 )
@@ -104,14 +105,12 @@ def test_transfer_kv(
             page_size=page_size,
             item_size=item_size,
         )
-        transfer_kv_per_layer_mla(
-            src_pool_host[layer_idx_to_test],
-            dst_pool_direct[layer_idx_to_test],
+        transfer_kv_direct(
+            [src_pool_host[layer_idx_to_test]],
+            [dst_pool_direct[layer_idx_to_test]],
             src_indices_host,
             dst_indices_device,
-            io_backend="direct",
             page_size=page_size,
-            item_size=item_size,
         )
     else:
         for layer_id in range(num_layers):
@@ -121,29 +120,34 @@ def test_transfer_kv(
                 src_indices_host,
                 dst_indices_device,
             )
+        src_layers_device = torch.tensor(
+            [src_pool_host[layer_id].data_ptr() for layer_id in range(num_layers)],
+            dtype=torch.uint64,
+            device=device,
+        )
+        dst_layers_device = torch.tensor(
+            [dst_pool_kernel[layer_id].data_ptr() for layer_id in range(num_layers)],
+            dtype=torch.uint64,
+            device=device,
+        )
         transfer_kv_all_layer_mla(
-            src_pool_host,
-            dst_pool_kernel,
+            src_layers_device,
+            dst_layers_device,
             src_indices_device,
             dst_indices_device,
             io_backend="kernel",
-            page_size=page_size,
-            item_size=item_size,
+            item_size=item_size * dtype.itemsize,
             num_layers=num_layers,
-            src_layer_offset=total_items_in_pool * item_size,
-            dst_layer_offset=total_items_in_pool * item_size,
         )
-        transfer_kv_all_layer_mla(
-            src_pool_host,
-            dst_pool_direct,
+        transfer_kv_direct(
+            [src_pool_host[layer_id] for layer_id in range(num_layers)],
+            [dst_pool_direct[layer_id] for layer_id in range(num_layers)],
             src_indices_host,
             dst_indices_device,
-            io_backend="direct",
             page_size=page_size,
-            item_size=item_size,
-            num_layers=num_layers,
-            src_layer_offset=total_items_in_pool * item_size,
-            dst_layer_offset=total_items_in_pool * item_size,
         )
     torch.cuda.synchronize()
     torch.testing.assert_close(dst_pool_kernel, dst_pool_ref)
@@ -173,16 +177,15 @@ def test_transfer_kv(
             page_size=page_size,
             item_size=item_size,
         )
-        transfer_kv_per_layer(
-            src_k_pool[layer_idx_to_test],
-            dst_k_pool_direct[layer_idx_to_test],
-            src_v_pool[layer_idx_to_test],
-            dst_v_pool_direct[layer_idx_to_test],
+        transfer_kv_direct(
+            [src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]],
+            [
+                dst_k_pool_direct[layer_idx_to_test],
+                dst_v_pool_direct[layer_idx_to_test],
+            ],
             src_indices_host,
             dst_indices_device,
-            io_backend="direct",
             page_size=page_size,
-            item_size=item_size,
         )
     else:
         for layer_id in range(num_layers):
@@ -198,33 +201,52 @@ def test_transfer_kv(
                 src_indices_host,
                 dst_indices_device,
             )
+        src_k_layers_device = torch.tensor(
+            [src_k_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
+            dtype=torch.uint64,
+            device=device,
+        )
+        src_v_layers_device = torch.tensor(
+            [src_v_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
+            dtype=torch.uint64,
+            device=device,
+        )
+        dst_k_layers_device = torch.tensor(
+            [dst_k_pool_kernel[layer_id].data_ptr() for layer_id in range(num_layers)],
+            dtype=torch.uint64,
+            device=device,
+        )
+        dst_v_layers_device = torch.tensor(
+            [dst_v_pool_kernel[layer_id].data_ptr() for layer_id in range(num_layers)],
+            dtype=torch.uint64,
+            device=device,
+        )
         transfer_kv_all_layer(
-            src_k_pool,
-            dst_k_pool_kernel,
-            src_v_pool,
-            dst_v_pool_kernel,
+            src_k_layers_device,
+            dst_k_layers_device,
+            src_v_layers_device,
+            dst_v_layers_device,
             src_indices_device,
             dst_indices_device,
             io_backend="kernel",
-            page_size=page_size,
-            item_size=item_size,
+            item_size=item_size * dtype.itemsize,
             num_layers=num_layers,
-            src_layer_offset=total_items_in_pool * item_size,
-            dst_layer_offset=total_items_in_pool * item_size,
         )
-        transfer_kv_all_layer(
-            src_k_pool,
-            dst_k_pool_direct,
-            src_v_pool,
-            dst_v_pool_direct,
+        transfer_kv_direct(
+            [src_k_pool[layer_id] for layer_id in range(num_layers)]
+            + [src_v_pool[layer_id] for layer_id in range(num_layers)],
+            [dst_k_pool_direct[layer_id] for layer_id in range(num_layers)]
+            + [dst_v_pool_direct[layer_id] for layer_id in range(num_layers)],
             src_indices_host,
             dst_indices_device,
-            io_backend="direct",
             page_size=page_size,
-            item_size=item_size,
-            num_layers=num_layers,
-            src_layer_offset=total_items_in_pool * item_size,
-            dst_layer_offset=total_items_in_pool * item_size,
         )
     torch.cuda.synchronize()
     torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)