change / sglang / Commits / b4326330

Unverified commit b4326330, authored Jul 23, 2025 by Zhiqiang Xie; committed by GitHub, Jul 23, 2025

Hicache IO kernel refactoring (#8264)

Parent: 8abd3e77

Changes: showing 5 changed files with 524 additions and 259 deletions (+524 -259).
  sgl-kernel/csrc/common_extension.cc         +21   -16
  sgl-kernel/csrc/kvcacheio/transfer.cu       +269  -146
  sgl-kernel/include/sgl_kernel_ops.h         +38   -23
  sgl-kernel/python/sgl_kernel/kvcacheio.py   +130  -30
  sgl-kernel/tests/test_kvcacheio.py          +66   -44
sgl-kernel/csrc/common_extension.cc

@@ -249,34 +249,39 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
   m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
   m.def(
-      "transfer_kv_per_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
-      "dst_indices, int page_size) -> ()");
-  m.impl("transfer_kv_per_layer_direct", torch::kCUDA, &transfer_kv_per_layer_direct);
+      "transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
+      "dst_indices, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf);
   m.def(
-      "transfer_kv_all_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
-      "dst_indices, int item_size, int num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int "
+      "transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, "
+      "Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int "
       "num_warps_per_block) -> ()");
   m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer);
   m.def(
-      "transfer_kv_all_layer_direct(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
-      "dst_indices, int page_size, int num_layers) -> ()");
-  m.impl("transfer_kv_all_layer_direct", torch::kCUDA, &transfer_kv_all_layer_direct);
+      "transfer_kv_all_layer_lf_pf(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, "
+      "Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int "
+      "num_warps_per_block) -> ()");
+  m.impl("transfer_kv_all_layer_lf_pf", torch::kCUDA, &transfer_kv_all_layer_lf_pf);
   m.def(
       "transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
       "block_quota, int num_warps_per_block) -> ()");
   m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
   m.def(
-      "transfer_kv_per_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size) "
-      "-> ()");
-  m.impl("transfer_kv_per_layer_mla_direct", torch::kCUDA, &transfer_kv_per_layer_mla_direct);
+      "transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, "
+      "int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf);
   m.def(
-      "transfer_kv_all_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
-      "num_layers, int src_layer_offset, int dst_layer_offset, int block_quota, int num_warps_per_block) -> ()");
+      "transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int "
+      "item_size, int num_layers, int block_quota, int num_warps_per_block) -> ()");
   m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
   m.def(
-      "transfer_kv_all_layer_mla_direct(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int page_size, "
-      "int num_layers) -> ()");
-  m.impl("transfer_kv_all_layer_mla_direct", torch::kCUDA, &transfer_kv_all_layer_mla_direct);
+      "transfer_kv_all_layer_mla_lf_pf(Tensor src_layers, Tensor dst, Tensor src_indices, Tensor dst_indices, "
+      "int item_size, int dst_layout_dim, int num_layers, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_all_layer_mla_lf_pf", torch::kCUDA, &transfer_kv_all_layer_mla_lf_pf);
+  m.def(
+      "transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
+      "page_size) -> ()");
+  m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
   /*
    * From csrc/moe/cutlass_moe/w4a8
...
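The schemas above are the C++ side of the change; the Python wrappers in sgl-kernel/python/sgl_kernel/kvcacheio.py (further down in this commit) dispatch to them through torch.ops. A minimal sketch of how the newly added transfer_kv_direct op is reached from Python (shapes are illustrative; it assumes the sgl_kernel extension is installed and that importing it registers the op namespace):

    import torch
    import sgl_kernel  # assumption: importing the package loads the extension and registers torch.ops.sgl_kernel

    # Two per-layer pools (K and V) on the host, with matching destinations on the GPU.
    src_k = torch.randn(16, 128, dtype=torch.float16)
    src_v = torch.randn(16, 128, dtype=torch.float16)
    dst_k = torch.empty_like(src_k, device="cuda")
    dst_v = torch.empty_like(src_v, device="cuda")
    indices = torch.arange(4, dtype=torch.int64)

    # Matches the registered schema:
    # transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int page_size) -> ()
    torch.ops.sgl_kernel.transfer_kv_direct([src_k, src_v], [dst_k, dst_v], indices, indices, 1)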
sgl-kernel/csrc/kvcacheio/transfer.cu

@@ -22,17 +22,40 @@ transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_
   }
 }
 
-__device__ __forceinline__ int64_t
-get_global_offset_lf(int64_t layer_id, int64_t layer_dim, int64_t page_id, int64_t item_size_bytes) {
+// todo, structs for different memory layout
+template <typename T>
+__device__ __forceinline__ T* get_global_offset_lf(
+    T* base,
+    const uintptr_t* __restrict__ /*unused*/,
+    int64_t layer_id,
+    int64_t layer_dim,
+    int64_t page_id,
+    int64_t item_size_bytes) {
   // layer first
-  return layer_id * layer_dim + page_id * item_size_bytes;
+  return base + layer_id * layer_dim + page_id * item_size_bytes;
 }
 
-__device__ __forceinline__ int64_t
-get_global_offset_pf(int64_t layer_id, int64_t page_dim, int64_t page_id, int64_t item_size_bytes) {
+template <typename T>
+__device__ __forceinline__ T* get_global_offset_pf(
+    T* base,
+    const uintptr_t* __restrict__ /*unused*/,
+    int64_t layer_id,
+    int64_t page_dim,
+    int64_t page_id,
+    int64_t item_size_bytes) {
   // page first
-  return page_id * page_dim + layer_id * item_size_bytes;
+  return base + page_id * page_dim + layer_id * item_size_bytes;
 }
 
+// get offset from layer base table when layers are not contiguous
+template <typename T>
+__device__ __forceinline__ T* get_global_offset_lf_tbl(
+    T* /*unused*/,
+    const uintptr_t* __restrict__ layer_base_tbl,
+    int64_t layer_id,
+    int64_t /*unused*/,
+    int64_t page_id,
+    int64_t item_size_bytes) {
+  return reinterpret_cast<T*>(layer_base_tbl[layer_id]) + page_id * item_size_bytes;
+}
+
 template <auto SrcOffsetFn, auto DstOffsetFn, bool IsMLA>
...
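The three helpers differ only in how a (layer_id, page_id) pair becomes a byte address: layer-first pools store all items of a layer contiguously, page-first pools store all layers of an item contiguously, and the _tbl variant looks the per-layer base pointer up in a table when layers are not contiguous at all. A small plain-Python sketch of the same arithmetic (all quantities in bytes):

    def offset_lf(layer_id, layer_dim, page_id, item_size_bytes):
        # layer first: all items of layer 0, then all items of layer 1, ...
        return layer_id * layer_dim + page_id * item_size_bytes

    def offset_pf(layer_id, page_dim, page_id, item_size_bytes):
        # page first: all layers of item 0, then all layers of item 1, ...
        return page_id * page_dim + layer_id * item_size_bytes

    # Example: 4 layers, 8 items per layer, 256-byte items.
    item, layers, pages = 256, 4, 8
    layer_dim = pages * item   # stride between layers in a layer-first pool
    page_dim = layers * item   # stride between items in a page-first pool
    assert offset_lf(2, layer_dim, 5, item) == 2 * 2048 + 5 * 256
    assert offset_pf(2, page_dim, 5, item) == 5 * 1024 + 2 * 256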
@@ -49,42 +72,37 @@ __global__ void transfer_kernel_impl(
     int64_t items_per_warp,
     int64_t item_size_bytes,
     int64_t src_layout_dim,
-    int64_t dst_layout_dim) {
+    int64_t dst_layout_dim,
+    const uintptr_t* __restrict__ src_k_layer_tbl,
+    const uintptr_t* __restrict__ dst_k_layer_tbl,
+    const uintptr_t* __restrict__ src_v_layer_tbl,
+    const uintptr_t* __restrict__ dst_v_layer_tbl) {
   int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   int32_t lane_id = tid % 32;
   int32_t warp_id = tid / 32;
   for (int i = 0; i < items_per_warp; ++i) {
-    int32_t item_id = warp_id * items_per_warp + i;
+    int64_t item_id = warp_id * items_per_warp + i;
     if (item_id >= num_items) {
-      return;
+      break;
     }
     const int64_t src_page_id = src_indices[item_id];
     const int64_t dst_page_id = dst_indices[item_id];
     // Loop over layers if necessary
     for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) {
-      // Calculate offsets using the provided function pointers
-      const int64_t src_offset = SrcOffsetFn(layer_id, src_layout_dim, src_page_id, item_size_bytes);
-      const int64_t dst_offset = DstOffsetFn(layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
-      if constexpr (IsMLA) {
-        transfer_item_warp(
-            lane_id, static_cast<const char*>(src_k) + src_offset, static_cast<char*>(dst_k) + dst_offset, item_size_bytes);
-      } else {
-        transfer_item_warp(
-            lane_id, static_cast<const char*>(src_k) + src_offset, static_cast<char*>(dst_k) + dst_offset, item_size_bytes);
-        transfer_item_warp(
-            lane_id, static_cast<const char*>(src_v) + src_offset, static_cast<char*>(dst_v) + dst_offset, item_size_bytes);
-      }
+      const char* src_ptr = SrcOffsetFn(
+          static_cast<const char*>(src_k), src_k_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes);
+      char* dst_ptr = DstOffsetFn(
+          static_cast<char*>(dst_k), dst_k_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
+      transfer_item_warp(lane_id, src_ptr, dst_ptr, item_size_bytes);
+      if constexpr (!IsMLA) {
+        const char* src_v_ptr = SrcOffsetFn(
+            static_cast<const char*>(src_v), src_v_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes);
+        char* dst_v_ptr = DstOffsetFn(
+            static_cast<char*>(dst_v), dst_v_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes);
+        transfer_item_warp(lane_id, src_v_ptr, dst_v_ptr, item_size_bytes);
+      }
     }
   }
 }
...
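The kernel's work split (item_id = warp_id * items_per_warp + i, with an early break at num_items) pairs with the grid sizing done in transfer_kv_launcher below. A small host-side sketch in plain Python showing that the split covers every item exactly once (the constants mirror the launcher's div_up logic; this is not CUDA code):

    def covered_items(num_items, block_quota, num_warps_per_block):
        div_up = lambda x, y: (x + y - 1) // y
        items_per_warp = div_up(num_items, block_quota * num_warps_per_block)
        num_blocks = div_up(num_items, items_per_warp * num_warps_per_block)
        seen = []
        for warp_id in range(num_blocks * num_warps_per_block):
            for i in range(items_per_warp):
                item_id = warp_id * items_per_warp + i
                if item_id >= num_items:
                    break
                seen.append(item_id)
        return seen

    assert covered_items(1000, block_quota=2, num_warps_per_block=32) == list(range(1000))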
@@ -103,44 +121,54 @@ void transfer_kv_launcher(
     int64_t item_size,
     int64_t src_layout_dim,
     int64_t dst_layout_dim,
+    const at::Tensor& src_k_layers,
+    const at::Tensor& dst_k_layers,
+    const at::Tensor& src_v_layers,
+    const at::Tensor& dst_v_layers,
     int64_t block_quota,
     int64_t num_warps_per_block) {
-  TORCH_CHECK(src_k.scalar_type() == dst_k.scalar_type(), "Source and destination keys must have the same type");
   TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor");
   TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor");
   TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long");
   TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long");
   TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
-  if (!IsMLA) {
-    TORCH_CHECK(src_v.scalar_type() == dst_v.scalar_type(), "Source and destination values must have the same type");
-  }
-  int dtype_size = src_k.element_size();
-  TORCH_CHECK((item_size * dtype_size) % 8 == 0, "Item byte size must be divisible by 8");
-  auto div_up = [](int32_t x, int32_t y) { return (x + y - 1) / y; };
+  TORCH_CHECK(item_size % 8 == 0, "Item byte size must be divisible by 8");
+  auto div_up = [](int64_t x, int64_t y) { return (x + y - 1) / y; };
   const int64_t num_items = src_indices.numel();
   const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
   const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
   dim3 grid_dim(num_blocks, 1, 1);
   const int32_t threads_per_block = num_warps_per_block * 32;
+  const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr;
+  void* dst_k_ptr = dst_k.defined() ? dst_k.data_ptr() : nullptr;
+  const void* src_v_ptr = IsMLA || !src_v.defined() ? nullptr : src_v.data_ptr();
+  void* dst_v_ptr = IsMLA || !dst_v.defined() ? nullptr : dst_v.data_ptr();
+  const uintptr_t* src_k_tbl_ptr = src_k_layers.defined() ? src_k_layers.data_ptr<uintptr_t>() : nullptr;
+  const uintptr_t* dst_k_tbl_ptr = dst_k_layers.defined() ? dst_k_layers.data_ptr<uintptr_t>() : nullptr;
+  const uintptr_t* src_v_tbl_ptr = IsMLA || !src_v_layers.defined() ? nullptr : src_v_layers.data_ptr<uintptr_t>();
+  const uintptr_t* dst_v_tbl_ptr = IsMLA || !dst_v_layers.defined() ? nullptr : dst_v_layers.data_ptr<uintptr_t>();
   cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
   transfer_kernel_impl<SrcOffsetFn, DstOffsetFn, IsMLA><<<grid_dim, threads_per_block, 0, torch_current_stream>>>(
-      src_k.data_ptr(),
-      dst_k.data_ptr(),
-      (IsMLA ? nullptr : src_v.data_ptr()),
-      (IsMLA ? nullptr : dst_v.data_ptr()),
+      src_k_ptr,
+      dst_k_ptr,
+      src_v_ptr,
+      dst_v_ptr,
       src_indices.data_ptr<int64_t>(),
       dst_indices.data_ptr<int64_t>(),
       start_layer_id,
       num_layers_to_process,
       num_items,
       items_per_warp,
-      item_size * dtype_size,
-      src_layout_dim * dtype_size,
-      dst_layout_dim * dtype_size);
+      item_size,
+      src_layout_dim,
+      dst_layout_dim,
+      src_k_tbl_ptr,
+      dst_k_tbl_ptr,
+      src_v_tbl_ptr,
+      dst_v_tbl_ptr);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
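The launcher now accepts optional per-layer base-address tables (src_k_layers and friends) alongside the contiguous-pool tensors; whichever side is undefined reaches the kernel as nullptr. On the Python side these tables are ordinary uint64 tensors of data_ptr() values, as the updated tests later in this commit build them. A sketch of that construction (the helper name is illustrative, not part of the change):

    import torch

    def build_layer_table(layer_pool, device="cuda"):
        # One 64-bit base address per layer, resident on the device so the kernel
        # can index it by layer_id (feeds the uintptr_t* layer table arguments).
        return torch.tensor([t.data_ptr() for t in layer_pool], dtype=torch.uint64, device=device)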
@@ -154,11 +182,28 @@ void transfer_kv_per_layer(
     int64_t item_size,
     int64_t block_quota,
     int64_t num_warps_per_block) {
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
-      src_k, dst_k, src_v, dst_v, src_indices, dst_indices, 0, 1, item_size, 0, 0, block_quota, num_warps_per_block);
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, false>(
+      src_k, dst_k, src_v, dst_v, src_indices, dst_indices,
+      0, 1, item_size, 0, 0,
+      empty, empty, empty, empty,
+      block_quota, num_warps_per_block);
 }
 
-void transfer_kv_all_layer(
+void transfer_kv_per_layer_pf_lf(
     const at::Tensor src_k,
     at::Tensor dst_k,
     const at::Tensor src_v,
@@ -166,12 +211,11 @@ void transfer_kv_all_layer(
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
     int64_t item_size,
-    int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
+    int64_t src_layout_dim,
     int64_t block_quota,
     int64_t num_warps_per_block) {
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, false>(
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, false>(
       src_k,
       dst_k,
       src_v,
@@ -179,10 +223,81 @@ void transfer_kv_all_layer(
       src_indices,
       dst_indices,
       0,
-      num_layers,
+      1,
       item_size,
-      src_layer_offset,
-      dst_layer_offset,
+      src_layout_dim,
+      0,
+      empty, empty, empty, empty,
       block_quota,
       num_warps_per_block);
 }
+
+void transfer_kv_all_layer(
+    const at::Tensor src_k_layers,
+    const at::Tensor dst_k_layers,
+    const at::Tensor src_v_layers,
+    const at::Tensor dst_v_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, false>(
+      empty, empty, empty, empty, src_indices, dst_indices,
+      0, num_layers, item_size, 0, 0,
+      src_k_layers, dst_k_layers, src_v_layers, dst_v_layers,
+      block_quota, num_warps_per_block);
+}
+
+void transfer_kv_all_layer_lf_pf(
+    const at::Tensor src_k_layers,
+    at::Tensor dst_k,
+    const at::Tensor src_v_layers,
+    at::Tensor dst_v,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, false>(
+      empty, dst_k, empty, dst_v, src_indices, dst_indices,
+      0, num_layers, item_size, 0, dst_layout_dim,
+      src_k_layers, empty, src_v_layers, empty,
+      block_quota, num_warps_per_block);
+}
@@ -195,12 +310,12 @@ void transfer_kv_per_layer_mla(
     int64_t item_size,
     int64_t block_quota,
     int64_t num_warps_per_block) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf<const char>, get_global_offset_lf<char>, true>(
       src,
       dst,
-      empty_tensor,
-      empty_tensor,
+      empty,
+      empty,
       src_indices,
       dst_indices,
       0,
@@ -208,41 +323,110 @@ void transfer_kv_per_layer_mla(
       item_size,
       0,
       0,
+      empty, empty, empty, empty,
       block_quota,
       num_warps_per_block);
 }
 
-void transfer_kv_all_layer_mla(
+void transfer_kv_per_layer_mla_pf_lf(
     const at::Tensor src,
     at::Tensor dst,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
     int64_t item_size,
-    int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
+    int64_t src_layout_dim,
     int64_t block_quota,
     int64_t num_warps_per_block) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_launcher<get_global_offset_lf, get_global_offset_lf, true>(
-      src, dst, empty_tensor, empty_tensor, src_indices, dst_indices,
-      0, num_layers, item_size, src_layer_offset, dst_layer_offset,
-      block_quota, num_warps_per_block);
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_pf<const char>, get_global_offset_lf<char>, true>(
+      src, dst, empty, empty, src_indices, dst_indices,
+      0, 1, item_size, src_layout_dim, 0,
+      empty, empty, empty, empty,
+      block_quota, num_warps_per_block);
+}
+
+void transfer_kv_all_layer_mla(
+    const at::Tensor src_layers,
+    const at::Tensor dst_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_lf_tbl<char>, true>(
+      empty, empty, empty, empty, src_indices, dst_indices,
+      0, num_layers, item_size, 0, 0,
+      src_layers, dst_layers, empty, empty,
+      block_quota, num_warps_per_block);
+}
+
+void transfer_kv_all_layer_mla_lf_pf(
+    const at::Tensor src_layers,
+    at::Tensor dst,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block) {
+  TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers");
+  at::Tensor empty;
+  transfer_kv_launcher<get_global_offset_lf_tbl<const char>, get_global_offset_pf<char>, true>(
+      empty, dst, empty, empty, src_indices, dst_indices,
+      0, num_layers, item_size, 0, dst_layout_dim,
+      src_layers, empty, empty, empty,
+      block_quota, num_warps_per_block);
 }
 
 inline void transfer_page_direct(
-    const at::Tensor src_buffer,
-    at::Tensor dst_buffer,
+    const at::Tensor& src_buffer,
+    at::Tensor& dst_buffer,
     int64_t src_page_index,
     int64_t dst_page_index,
     int64_t page_size) {
@@ -252,16 +436,14 @@ inline void transfer_page_direct(
       /* non_blocking= */ true);
 }
 
-template <bool IsMLA, bool AllLayers>
-inline void transfer_kv_direct_impl(
-    const at::Tensor& src_k,
-    at::Tensor& dst_k,
-    const at::Tensor& src_v_opt,  // Only used when IsMLA is false (for src_v)
-    at::Tensor& dst_v_opt,        // Only used when IsMLA is false (for dst_v)
-    const at::Tensor& src_indices,
-    const at::Tensor& dst_indices,
-    int64_t page_size,
-    int64_t num_layers = 1) {
+void transfer_kv_direct(
+    const std::vector<at::Tensor>& src_layers,
+    std::vector<at::Tensor> dst_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t page_size) {
+  TORCH_CHECK(
+      src_layers.size() == dst_layers.size(), "Source and destination layers must have the same number of layers");
   TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
   TORCH_CHECK(page_size > 0, "Page size must be positive");
   TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");
@@ -270,73 +452,14 @@ inline void transfer_kv_direct_impl(
   auto dst_indices_cpu = dst_indices.cpu();
   const int64_t num_pages = src_indices_cpu.size(0) / page_size;
-  for (int64_t i = 0; i < num_pages; ++i) {
-    auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
-    auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
-    if constexpr (AllLayers) {
-      for (int64_t j = 0; j < num_layers; ++j) {
-        if constexpr (IsMLA) {
-          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
-        } else {
-          transfer_page_direct(src_k.select(0, j), dst_k.select(0, j), s_index, d_index, page_size);
-          transfer_page_direct(src_v_opt.select(0, j), dst_v_opt.select(0, j), s_index, d_index, page_size);
-        }
-      }
-    } else {
-      // Per-layer
-      if constexpr (IsMLA) {
-        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
-      } else {
-        transfer_page_direct(src_k, dst_k, s_index, d_index, page_size);
-        transfer_page_direct(src_v_opt, dst_v_opt, s_index, d_index, page_size);
-      }
-    }
-  }
-}
-
-void transfer_kv_per_layer_direct(
-    const at::Tensor src_k,
-    at::Tensor dst_k,
-    const at::Tensor src_v,
-    at::Tensor dst_v,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size) {
-  transfer_kv_direct_impl<false, false>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size);
-}
-
-void transfer_kv_all_layer_direct(
-    const at::Tensor src_k,
-    at::Tensor dst_k,
-    const at::Tensor src_v,
-    at::Tensor dst_v,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers) {
-  transfer_kv_direct_impl<false, true>(src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers);
-}
-
-void transfer_kv_per_layer_mla_direct(
-    const at::Tensor src,
-    at::Tensor dst,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_direct_impl<true, false>(src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size);
-}
-
-void transfer_kv_all_layer_mla_direct(
-    const at::Tensor src,
-    at::Tensor dst,
-    const at::Tensor src_indices,
-    const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers) {
-  at::Tensor empty_tensor = at::Tensor();
-  transfer_kv_direct_impl<true, true>(src, dst, empty_tensor, empty_tensor, src_indices, dst_indices, page_size, num_layers);
-}
+  const int64_t num_layers = src_layers.size();
+  for (const auto i : c10::irange(num_pages)) {
+    auto src_index = src_indices_cpu[i * page_size].item<int64_t>();
+    auto dst_index = dst_indices_cpu[i * page_size].item<int64_t>();
+    for (const auto j : c10::irange(num_layers)) {
+      transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, page_size);
+    }
+  }
+}
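For reference, a rough Python analogue of the consolidated direct path (a sketch only: it assumes, as the checks and the tests suggest, that each index is a starting row and that transfer_page_direct copies page_size consecutive rows with a non-blocking copy):

    from typing import List
    import torch

    def transfer_kv_direct_py(src_layers: List[torch.Tensor], dst_layers: List[torch.Tensor],
                              src_indices: torch.Tensor, dst_indices: torch.Tensor,
                              page_size: int) -> None:
        assert len(src_layers) == len(dst_layers)
        assert src_indices.numel() == dst_indices.numel()
        assert page_size > 0 and src_indices.numel() % page_size == 0
        src_idx, dst_idx = src_indices.cpu(), dst_indices.cpu()
        for i in range(src_idx.numel() // page_size):
            s = int(src_idx[i * page_size])
            d = int(dst_idx[i * page_size])
            for src, dst in zip(src_layers, dst_layers):
                # copy one page (page_size rows) per layer, host<->device as needed
                dst[d : d + page_size].copy_(src[s : s + page_size], non_blocking=True)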
sgl-kernel/include/sgl_kernel_ops.h

@@ -399,38 +399,42 @@ void transfer_kv_per_layer(
     int64_t block_quota,
     int64_t num_warps_per_block);
 
-void transfer_kv_per_layer_direct(
+void transfer_kv_per_layer_pf_lf(
     const at::Tensor src_k,
     at::Tensor dst_k,
     const at::Tensor src_v,
     at::Tensor dst_v,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
-    int64_t page_size);
+    int64_t item_size,
+    int64_t src_layout_dim,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
 
 void transfer_kv_all_layer(
-    const at::Tensor src_k,
-    at::Tensor dst_k,
-    const at::Tensor src_v,
-    at::Tensor dst_v,
+    const at::Tensor src_k_layers,
+    const at::Tensor dst_k_layers,
+    const at::Tensor src_v_layers,
+    const at::Tensor dst_v_layers,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
     int64_t item_size,
     int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
     int64_t block_quota,
     int64_t num_warps_per_block);
 
-void transfer_kv_all_layer_direct(
-    const at::Tensor src_k,
+void transfer_kv_all_layer_lf_pf(
+    const at::Tensor src_k_layers,
     at::Tensor dst_k,
-    const at::Tensor src_v,
+    const at::Tensor src_v_layers,
     at::Tensor dst_v,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers);
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
 
 void transfer_kv_per_layer_mla(
     const at::Tensor src,
@@ -441,32 +445,43 @@ void transfer_kv_per_layer_mla(
     int64_t block_quota,
     int64_t num_warps_per_block);
 
-void transfer_kv_per_layer_mla_direct(
+void transfer_kv_per_layer_mla_pf_lf(
     const at::Tensor src,
     at::Tensor dst,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
-    int64_t page_size);
+    int64_t item_size,
+    int64_t src_layout_dim,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
 
 void transfer_kv_all_layer_mla(
-    const at::Tensor src,
-    at::Tensor dst,
+    const at::Tensor src_layers,
+    const at::Tensor dst_layers,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
     int64_t item_size,
     int64_t num_layers,
-    int64_t src_layer_offset,
-    int64_t dst_layer_offset,
     int64_t block_quota,
     int64_t num_warps_per_block);
 
-void transfer_kv_all_layer_mla_direct(
-    const at::Tensor src,
+void transfer_kv_all_layer_mla_lf_pf(
+    const at::Tensor src_layers,
     at::Tensor dst,
     const at::Tensor src_indices,
     const at::Tensor dst_indices,
-    int64_t page_size,
-    int64_t num_layers);
+    int64_t item_size,
+    int64_t dst_layout_dim,
+    int64_t num_layers,
+    int64_t block_quota,
+    int64_t num_warps_per_block);
+
+void transfer_kv_direct(
+    const std::vector<at::Tensor>& src_layers,
+    std::vector<at::Tensor> dst_layers,
+    const at::Tensor src_indices,
+    const at::Tensor dst_indices,
+    int64_t page_size);
 
 /*
  * From csrc/moe/cutlass_moe/w4a8
...
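Read the new suffixes against the launcher instantiations in transfer.cu: pf_lf moves items from a page-first source into a layer-first destination (hence the extra src_layout_dim), lf_pf gathers from layer-first per-layer pools into a page-first destination (hence dst_layout_dim), and the plain all-layer variants now take per-layer base-pointer tables instead of one contiguous pool. A compact summary of that reading (inferred from the code above, not an official naming table):

    # (source layout, destination layout) for the non-MLA entry points
    VARIANTS = {
        "transfer_kv_per_layer":        ("layer-first", "layer-first"),
        "transfer_kv_per_layer_pf_lf":  ("page-first",  "layer-first"),  # uses src_layout_dim
        "transfer_kv_all_layer":        ("layer table", "layer table"),
        "transfer_kv_all_layer_lf_pf":  ("layer table", "page-first"),   # uses dst_layout_dim
        "transfer_kv_direct":           ("per-layer tensor list", "per-layer tensor list"),
    }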
sgl-kernel/python/sgl_kernel/kvcacheio.py

+from typing import List
+
 import torch

@@ -22,36 +24,31 @@ def transfer_kv_per_layer(
             dst_v,
             src_indices,
             dst_indices,
-            item_size,
+            item_size * src_k.element_size(),  # todo, hot fix for compatibility
             block_quota,
             num_warps_per_block,
         )
     elif io_backend == "direct":
-        torch.ops.sgl_kernel.transfer_kv_per_layer_direct(
-            src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size
+        torch.ops.sgl_kernel.transfer_kv_direct(
+            [src_k, src_v], [dst_k, dst_v], src_indices, dst_indices, page_size
         )
     else:
         raise ValueError(f"Unsupported io backend")
 
 
-def transfer_kv_all_layer(
+def transfer_kv_per_layer_pf_lf(
     src_k: torch.Tensor,
     dst_k: torch.Tensor,
     src_v: torch.Tensor,
     dst_v: torch.Tensor,
     src_indices: torch.Tensor,
     dst_indices: torch.Tensor,
-    io_backend: str,
-    page_size: int,
     item_size: int,
-    num_layers: int,
-    src_layer_offset: int,
-    dst_layer_offset: int,
+    src_layout_dim: int,
     block_quota: int = 2,
     num_warps_per_block: int = 32,
 ):
-    if io_backend == "kernel":
-        torch.ops.sgl_kernel.transfer_kv_all_layer(
+    torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
         src_k,
         dst_k,
         src_v,
@@ -59,20 +56,84 @@ def transfer_kv_all_layer(
         src_indices,
         dst_indices,
         item_size,
+        src_layout_dim,
+        block_quota,
+        num_warps_per_block,
+    )
+
+
+def transfer_kv_all_layer(
+    src_k_layers: torch.Tensor,
+    dst_k_layers: torch.Tensor,
+    src_v_layers: torch.Tensor,
+    dst_v_layers: torch.Tensor,
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    io_backend: str,
+    item_size: int,
+    num_layers: int,
+    block_quota: int = 2,
+    num_warps_per_block: int = 32,
+):
+    if io_backend == "kernel":
+        torch.ops.sgl_kernel.transfer_kv_all_layer(
+            src_k_layers, dst_k_layers, src_v_layers, dst_v_layers,
+            src_indices, dst_indices,
+            item_size,
             num_layers,
-            src_layer_offset,
-            dst_layer_offset,
             block_quota,
             num_warps_per_block,
         )
     elif io_backend == "direct":
-        torch.ops.sgl_kernel.transfer_kv_all_layer_direct(
-            src_k, dst_k, src_v, dst_v, src_indices, dst_indices, page_size, num_layers
-        )
+        raise NotImplementedError("Deprecated interface")
     else:
         raise ValueError(f"Unsupported io backend")
+
+
+def transfer_kv_all_layer_lf_pf(
+    src_k_layers: torch.Tensor,
+    dst_k: torch.Tensor,
+    src_v_layers: torch.Tensor,
+    dst_v: torch.Tensor,
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    item_size: int,
+    dst_layout_dim: int,
+    num_layers: int,
+    block_quota: int = 2,
+    num_warps_per_block: int = 32,
+):
+    torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
+        src_k_layers, dst_k, src_v_layers, dst_v,
+        src_indices, dst_indices,
+        item_size, dst_layout_dim, num_layers,
+        block_quota, num_warps_per_block,
+    )
+
+
+def transfer_kv_direct(
+    src_layers: List[torch.Tensor],
+    dst_layers: List[torch.Tensor],
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    page_size: int,
+):
+    torch.ops.sgl_kernel.transfer_kv_direct(
+        src_layers, dst_layers, src_indices, dst_indices, page_size
+    )
 
 
 def transfer_kv_per_layer_mla(
     src: torch.Tensor,
     dst: torch.Tensor,
@@ -90,48 +151,87 @@ def transfer_kv_per_layer_mla(
             dst,
             src_indices,
             dst_indices,
-            item_size,
+            item_size * src.element_size(),  # todo, hot fix for compatibility
             block_quota,
             num_warps_per_block,
         )
     elif io_backend == "direct":
-        torch.ops.sgl_kernel.transfer_kv_per_layer_mla_direct(
-            src, dst, src_indices, dst_indices, page_size
+        torch.ops.sgl_kernel.transfer_kv_direct(
+            [src], [dst], src_indices, dst_indices, page_size
        )
     else:
         raise ValueError(f"Unsupported io backend")
 
 
-def transfer_kv_all_layer_mla(
+def transfer_kv_per_layer_mla_pf_lf(
     src: torch.Tensor,
     dst: torch.Tensor,
     src_indices: torch.Tensor,
     dst_indices: torch.Tensor,
-    io_backend: str,
-    page_size: int,
     item_size: int,
-    num_layers: int,
-    src_layer_offset: int,
-    dst_layer_offset: int,
+    src_layout_dim: int,
     block_quota: int = 2,
     num_warps_per_block: int = 32,
 ):
-    if io_backend == "kernel":
-        torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
-            src, dst, src_indices, dst_indices,
-            item_size, num_layers, src_layer_offset, dst_layer_offset,
-            block_quota, num_warps_per_block,
-        )
-    elif io_backend == "direct":
-        torch.ops.sgl_kernel.transfer_kv_all_layer_mla_direct(
-            src, dst, src_indices, dst_indices, page_size, num_layers
-        )
-    else:
-        raise ValueError(f"Unsupported io backend")
+    torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
+        src, dst, src_indices, dst_indices,
+        item_size, src_layout_dim,
+        block_quota, num_warps_per_block,
+    )
+
+
+def transfer_kv_all_layer_mla(
+    src_layers: torch.Tensor,
+    dst_layers: torch.Tensor,
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    io_backend: str,
+    item_size: int,
+    num_layers: int,
+    block_quota: int = 2,
+    num_warps_per_block: int = 32,
+):
+    if io_backend == "kernel":
+        torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
+            src_layers, dst_layers, src_indices, dst_indices,
+            item_size, num_layers,
+            block_quota, num_warps_per_block,
+        )
+    elif io_backend == "direct":
+        raise NotImplementedError("Deprecated interface")
+    else:
+        raise ValueError(f"Unsupported io backend")
+
+
+def transfer_kv_all_layer_mla_lf_pf(
+    src_layers: torch.Tensor,
+    dst: torch.Tensor,
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    item_size: int,
+    dst_layout_dim: int,
+    num_layers: int,
+    block_quota: int = 2,
+    num_warps_per_block: int = 32,
+):
+    torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
+        src_layers, dst, src_indices, dst_indices,
+        item_size, dst_layout_dim, num_layers,
+        block_quota, num_warps_per_block,
+    )
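A minimal usage example of the reworked wrapper API (it mirrors the updated tests below: item_size is now passed in bytes, and the per-layer pools are handed over as device tensors of base addresses; the pool shapes here are illustrative only):

    import torch
    from sgl_kernel.kvcacheio import transfer_kv_all_layer

    num_layers, num_items, head_dim = 4, 32, 128
    dtype = torch.float16

    src_k = [torch.randn(num_items, head_dim, dtype=dtype, device="cuda") for _ in range(num_layers)]
    src_v = [torch.randn(num_items, head_dim, dtype=dtype, device="cuda") for _ in range(num_layers)]
    dst_k = [torch.zeros(num_items, head_dim, dtype=dtype, device="cuda") for _ in range(num_layers)]
    dst_v = [torch.zeros(num_items, head_dim, dtype=dtype, device="cuda") for _ in range(num_layers)]

    table = lambda pool: torch.tensor([t.data_ptr() for t in pool], dtype=torch.uint64, device="cuda")
    indices = torch.arange(num_items, dtype=torch.int64, device="cuda")

    transfer_kv_all_layer(
        table(src_k), table(dst_k), table(src_v), table(dst_v),
        indices, indices,
        io_backend="kernel",
        item_size=head_dim * dtype.itemsize,  # bytes per item, matching the refactored op
        num_layers=num_layers,
    )
    torch.cuda.synchronize()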
sgl-kernel/tests/test_kvcacheio.py

@@ -3,6 +3,7 @@ import torch
 from sgl_kernel.kvcacheio import (
     transfer_kv_all_layer,
     transfer_kv_all_layer_mla,
+    transfer_kv_direct,
     transfer_kv_per_layer,
     transfer_kv_per_layer_mla,
 )

@@ -104,14 +105,12 @@ def test_transfer_kv(
             page_size=page_size,
             item_size=item_size,
         )
-        transfer_kv_per_layer_mla(
-            src_pool_host[layer_idx_to_test],
-            dst_pool_direct[layer_idx_to_test],
+        transfer_kv_direct(
+            [src_pool_host[layer_idx_to_test]],
+            [dst_pool_direct[layer_idx_to_test]],
             src_indices_host,
             dst_indices_device,
-            io_backend="direct",
             page_size=page_size,
-            item_size=item_size,
         )
     else:
         for layer_id in range(num_layers):

@@ -121,29 +120,34 @@ def test_transfer_kv(
                 src_indices_host,
                 dst_indices_device,
             )
+        src_layers_device = torch.tensor(
+            [src_pool_host[layer_id].data_ptr() for layer_id in range(num_layers)],
+            dtype=torch.uint64,
+            device=device,
+        )
+        dst_layers_device = torch.tensor(
+            [dst_pool_kernel[layer_id].data_ptr() for layer_id in range(num_layers)],
+            dtype=torch.uint64,
+            device=device,
+        )
         transfer_kv_all_layer_mla(
-            src_pool_host,
-            dst_pool_kernel,
+            src_layers_device,
+            dst_layers_device,
             src_indices_device,
             dst_indices_device,
             io_backend="kernel",
-            page_size=page_size,
-            item_size=item_size,
+            item_size=item_size * dtype.itemsize,
             num_layers=num_layers,
-            src_layer_offset=total_items_in_pool * item_size,
-            dst_layer_offset=total_items_in_pool * item_size,
         )
-        transfer_kv_all_layer_mla(
-            src_pool_host,
-            dst_pool_direct,
+        transfer_kv_direct(
+            [src_pool_host[layer_id] for layer_id in range(num_layers)],
+            [dst_pool_direct[layer_id] for layer_id in range(num_layers)],
             src_indices_host,
             dst_indices_device,
-            io_backend="direct",
             page_size=page_size,
-            item_size=item_size,
-            num_layers=num_layers,
-            src_layer_offset=total_items_in_pool * item_size,
-            dst_layer_offset=total_items_in_pool * item_size,
         )
     torch.cuda.synchronize()
     torch.testing.assert_close(dst_pool_kernel, dst_pool_ref)

@@ -173,16 +177,15 @@ def test_transfer_kv(
             page_size=page_size,
             item_size=item_size,
         )
-        transfer_kv_per_layer(
-            src_k_pool[layer_idx_to_test],
-            dst_k_pool_direct[layer_idx_to_test],
-            src_v_pool[layer_idx_to_test],
-            dst_v_pool_direct[layer_idx_to_test],
+        transfer_kv_direct(
+            [src_k_pool[layer_idx_to_test], src_v_pool[layer_idx_to_test]],
+            [
+                dst_k_pool_direct[layer_idx_to_test],
+                dst_v_pool_direct[layer_idx_to_test],
+            ],
             src_indices_host,
             dst_indices_device,
-            io_backend="direct",
             page_size=page_size,
-            item_size=item_size,
         )
     else:
         for layer_id in range(num_layers):

@@ -198,33 +201,52 @@ def test_transfer_kv(
                 src_indices_host,
                 dst_indices_device,
             )
+        src_k_layers_device = torch.tensor(
+            [src_k_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
+            dtype=torch.uint64,
+            device=device,
+        )
+        src_v_layers_device = torch.tensor(
+            [src_v_pool[layer_id].data_ptr() for layer_id in range(num_layers)],
+            dtype=torch.uint64,
+            device=device,
+        )
+        dst_k_layers_device = torch.tensor(
+            [dst_k_pool_kernel[layer_id].data_ptr() for layer_id in range(num_layers)],
+            dtype=torch.uint64,
+            device=device,
+        )
+        dst_v_layers_device = torch.tensor(
+            [dst_v_pool_kernel[layer_id].data_ptr() for layer_id in range(num_layers)],
+            dtype=torch.uint64,
+            device=device,
+        )
         transfer_kv_all_layer(
-            src_k_pool,
-            dst_k_pool_kernel,
-            src_v_pool,
-            dst_v_pool_kernel,
+            src_k_layers_device,
+            dst_k_layers_device,
+            src_v_layers_device,
+            dst_v_layers_device,
             src_indices_device,
             dst_indices_device,
             io_backend="kernel",
-            page_size=page_size,
-            item_size=item_size,
+            item_size=item_size * dtype.itemsize,
             num_layers=num_layers,
-            src_layer_offset=total_items_in_pool * item_size,
-            dst_layer_offset=total_items_in_pool * item_size,
         )
-        transfer_kv_all_layer(
-            src_k_pool,
-            dst_k_pool_direct,
-            src_v_pool,
-            dst_v_pool_direct,
+        transfer_kv_direct(
+            [src_k_pool[layer_id] for layer_id in range(num_layers)]
+            + [src_v_pool[layer_id] for layer_id in range(num_layers)],
+            [dst_k_pool_direct[layer_id] for layer_id in range(num_layers)]
+            + [dst_v_pool_direct[layer_id] for layer_id in range(num_layers)],
             src_indices_host,
             dst_indices_device,
-            io_backend="direct",
             page_size=page_size,
-            item_size=item_size,
-            num_layers=num_layers,
-            src_layer_offset=total_items_in_pool * item_size,
-            dst_layer_offset=total_items_in_pool * item_size,
         )
     torch.cuda.synchronize()
     torch.testing.assert_close(dst_k_pool_kernel, dst_k_pool_ref)
...