Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d40846d4
Unverified
Commit
d40846d4
authored
Jul 24, 2025
by
Zhiqiang Xie
Committed by
GitHub
Jul 25, 2025
Browse files
breakdown kernel update (#8334)
parent
145482f4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
80 deletions
+44
-80
sgl-kernel/python/sgl_kernel/kvcacheio.py
sgl-kernel/python/sgl_kernel/kvcacheio.py
+42
-72
sgl-kernel/tests/test_kvcacheio.py
sgl-kernel/tests/test_kvcacheio.py
+2
-8
No files found.
sgl-kernel/python/sgl_kernel/kvcacheio.py
View file @
d40846d4
...
@@ -10,30 +10,21 @@ def transfer_kv_per_layer(
...
@@ -10,30 +10,21 @@ def transfer_kv_per_layer(
dst_v
:
torch
.
Tensor
,
dst_v
:
torch
.
Tensor
,
src_indices
:
torch
.
Tensor
,
src_indices
:
torch
.
Tensor
,
dst_indices
:
torch
.
Tensor
,
dst_indices
:
torch
.
Tensor
,
io_backend
:
str
,
page_size
:
int
,
item_size
:
int
,
item_size
:
int
,
block_quota
:
int
=
2
,
block_quota
:
int
=
2
,
num_warps_per_block
:
int
=
32
,
num_warps_per_block
:
int
=
32
,
):
):
if
io_backend
==
"kernel"
:
torch
.
ops
.
sgl_kernel
.
transfer_kv_per_layer
(
torch
.
ops
.
sgl_kernel
.
transfer_kv_per_layer
(
src_k
,
src_k
,
dst_k
,
dst_k
,
src_v
,
src_v
,
dst_v
,
dst_v
,
src_indices
,
src_indices
,
dst_indices
,
dst_indices
,
item_size
,
item_size
*
src_k
.
element_size
(),
# todo, hot fix for compatibility
block_quota
,
block_quota
,
num_warps_per_block
,
num_warps_per_block
,
)
)
elif
io_backend
==
"direct"
:
torch
.
ops
.
sgl_kernel
.
transfer_kv_direct
(
[
src_k
,
src_v
],
[
dst_k
,
dst_v
],
src_indices
,
dst_indices
,
page_size
)
else
:
raise
ValueError
(
f
"Unsupported io backend"
)
def
transfer_kv_per_layer_pf_lf
(
def
transfer_kv_per_layer_pf_lf
(
...
@@ -69,29 +60,23 @@ def transfer_kv_all_layer(
...
@@ -69,29 +60,23 @@ def transfer_kv_all_layer(
dst_v_layers
:
torch
.
Tensor
,
dst_v_layers
:
torch
.
Tensor
,
src_indices
:
torch
.
Tensor
,
src_indices
:
torch
.
Tensor
,
dst_indices
:
torch
.
Tensor
,
dst_indices
:
torch
.
Tensor
,
io_backend
:
str
,
item_size
:
int
,
item_size
:
int
,
num_layers
:
int
,
num_layers
:
int
,
block_quota
:
int
=
2
,
block_quota
:
int
=
2
,
num_warps_per_block
:
int
=
32
,
num_warps_per_block
:
int
=
32
,
):
):
if
io_backend
==
"kernel"
:
torch
.
ops
.
sgl_kernel
.
transfer_kv_all_layer
(
torch
.
ops
.
sgl_kernel
.
transfer_kv_all_layer
(
src_k_layers
,
src_k_layers
,
dst_k_layers
,
dst_k_layers
,
src_v_layers
,
src_v_layers
,
dst_v_layers
,
dst_v_layers
,
src_indices
,
src_indices
,
dst_indices
,
dst_indices
,
item_size
,
item_size
,
num_layers
,
num_layers
,
block_quota
,
block_quota
,
num_warps_per_block
,
num_warps_per_block
,
)
)
elif
io_backend
==
"direct"
:
raise
NotImplementedError
(
"Deprecated interface"
)
else
:
raise
ValueError
(
f
"Unsupported io backend"
)
def
transfer_kv_all_layer_lf_pf
(
def
transfer_kv_all_layer_lf_pf
(
...
@@ -139,28 +124,19 @@ def transfer_kv_per_layer_mla(
...
@@ -139,28 +124,19 @@ def transfer_kv_per_layer_mla(
dst
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
src_indices
:
torch
.
Tensor
,
src_indices
:
torch
.
Tensor
,
dst_indices
:
torch
.
Tensor
,
dst_indices
:
torch
.
Tensor
,
io_backend
:
str
,
page_size
:
int
,
item_size
:
int
,
item_size
:
int
,
block_quota
:
int
=
2
,
block_quota
:
int
=
2
,
num_warps_per_block
:
int
=
32
,
num_warps_per_block
:
int
=
32
,
):
):
if
io_backend
==
"kernel"
:
torch
.
ops
.
sgl_kernel
.
transfer_kv_per_layer_mla
(
torch
.
ops
.
sgl_kernel
.
transfer_kv_per_layer_mla
(
src
,
src
,
dst
,
dst
,
src_indices
,
src_indices
,
dst_indices
,
dst_indices
,
item_size
,
item_size
*
src
.
element_size
(),
# todo, hot fix for compatibility
block_quota
,
block_quota
,
num_warps_per_block
,
num_warps_per_block
,
)
)
elif
io_backend
==
"direct"
:
torch
.
ops
.
sgl_kernel
.
transfer_kv_direct
(
[
src
],
[
dst
],
src_indices
,
dst_indices
,
page_size
)
else
:
raise
ValueError
(
f
"Unsupported io backend"
)
def
transfer_kv_per_layer_mla_pf_lf
(
def
transfer_kv_per_layer_mla_pf_lf
(
...
@@ -190,27 +166,21 @@ def transfer_kv_all_layer_mla(
...
@@ -190,27 +166,21 @@ def transfer_kv_all_layer_mla(
dst_layers
:
torch
.
Tensor
,
dst_layers
:
torch
.
Tensor
,
src_indices
:
torch
.
Tensor
,
src_indices
:
torch
.
Tensor
,
dst_indices
:
torch
.
Tensor
,
dst_indices
:
torch
.
Tensor
,
io_backend
:
str
,
item_size
:
int
,
item_size
:
int
,
num_layers
:
int
,
num_layers
:
int
,
block_quota
:
int
=
2
,
block_quota
:
int
=
2
,
num_warps_per_block
:
int
=
32
,
num_warps_per_block
:
int
=
32
,
):
):
if
io_backend
==
"kernel"
:
torch
.
ops
.
sgl_kernel
.
transfer_kv_all_layer_mla
(
torch
.
ops
.
sgl_kernel
.
transfer_kv_all_layer_mla
(
src_layers
,
src_layers
,
dst_layers
,
dst_layers
,
src_indices
,
src_indices
,
dst_indices
,
dst_indices
,
item_size
,
item_size
,
num_layers
,
num_layers
,
block_quota
,
block_quota
,
num_warps_per_block
,
num_warps_per_block
,
)
)
elif
io_backend
==
"direct"
:
raise
NotImplementedError
(
"Deprecated interface"
)
else
:
raise
ValueError
(
f
"Unsupported io backend"
)
def
transfer_kv_all_layer_mla_lf_pf
(
def
transfer_kv_all_layer_mla_lf_pf
(
...
...
sgl-kernel/tests/test_kvcacheio.py
View file @
d40846d4
...
@@ -101,9 +101,7 @@ def test_transfer_kv(
...
@@ -101,9 +101,7 @@ def test_transfer_kv(
dst_pool_kernel
[
layer_idx_to_test
],
dst_pool_kernel
[
layer_idx_to_test
],
src_indices_device
,
src_indices_device
,
dst_indices_device
,
dst_indices_device
,
io_backend
=
"kernel"
,
item_size
=
item_size
*
dtype
.
itemsize
,
page_size
=
page_size
,
item_size
=
item_size
,
)
)
transfer_kv_direct
(
transfer_kv_direct
(
[
src_pool_host
[
layer_idx_to_test
]],
[
src_pool_host
[
layer_idx_to_test
]],
...
@@ -138,7 +136,6 @@ def test_transfer_kv(
...
@@ -138,7 +136,6 @@ def test_transfer_kv(
dst_layers_device
,
dst_layers_device
,
src_indices_device
,
src_indices_device
,
dst_indices_device
,
dst_indices_device
,
io_backend
=
"kernel"
,
item_size
=
item_size
*
dtype
.
itemsize
,
item_size
=
item_size
*
dtype
.
itemsize
,
num_layers
=
num_layers
,
num_layers
=
num_layers
,
)
)
...
@@ -173,9 +170,7 @@ def test_transfer_kv(
...
@@ -173,9 +170,7 @@ def test_transfer_kv(
dst_v_pool_kernel
[
layer_idx_to_test
],
dst_v_pool_kernel
[
layer_idx_to_test
],
src_indices_device
,
src_indices_device
,
dst_indices_device
,
dst_indices_device
,
io_backend
=
"kernel"
,
item_size
=
item_size
*
dtype
.
itemsize
,
page_size
=
page_size
,
item_size
=
item_size
,
)
)
transfer_kv_direct
(
transfer_kv_direct
(
[
src_k_pool
[
layer_idx_to_test
],
src_v_pool
[
layer_idx_to_test
]],
[
src_k_pool
[
layer_idx_to_test
],
src_v_pool
[
layer_idx_to_test
]],
...
@@ -235,7 +230,6 @@ def test_transfer_kv(
...
@@ -235,7 +230,6 @@ def test_transfer_kv(
dst_v_layers_device
,
dst_v_layers_device
,
src_indices_device
,
src_indices_device
,
dst_indices_device
,
dst_indices_device
,
io_backend
=
"kernel"
,
item_size
=
item_size
*
dtype
.
itemsize
,
item_size
=
item_size
*
dtype
.
itemsize
,
num_layers
=
num_layers
,
num_layers
=
num_layers
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment