sglang commit 5be8c2f7 (unverified)
Page first direct IO kernel (#10060)

Authored Sep 10, 2025 by huangtingwei; committed via GitHub on Sep 10, 2025
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Parent: 737d73ed

Showing 5 changed files with 358 additions and 2 deletions:

- sgl-kernel/csrc/common_extension.cc (+8, -0)
- sgl-kernel/csrc/kvcacheio/transfer.cu (+80, -2)
- sgl-kernel/include/sgl_kernel_ops.h (+15, -0)
- sgl-kernel/python/sgl_kernel/kvcacheio.py (+25, -0)
- sgl-kernel/tests/test_kvcacheio.py (+230, -0)
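The commit adds page-first ("pf") direct-copy variants of the existing KV-cache transfer kernels. As a quick orientation before the diffs, here is a minimal sketch of the two pool layouts the new kernels convert between, inferred from the tensor shapes used in the new tests; the names and sizes are illustrative, not part of the commit:

import torch

num_layers, total_items, page_size, item_size = 4, 20480, 64, 256
total_pages = total_items // page_size

# Layer-first ("lf") pool: one device tensor per layer, indexed by token slot.
lf_pool = torch.randn(num_layers, total_items, item_size, device="cuda")
lf_ptrs = [lf_pool[i] for i in range(num_layers)]

# Page-first ("pf") pool: pinned host memory indexed by page, with every layer
# of a page stored contiguously, which suits page-granular direct IO.
pf_pool = torch.zeros(total_pages, num_layers, page_size, item_size).pin_memory()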
sgl-kernel/csrc/common_extension.cc

@@ -331,6 +331,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
       "page_size) -> ()");
   m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
+  m.def(
+      "transfer_kv_per_layer_direct_pf_lf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
+      "Tensor dst_indices, int layer_id, int page_size)->() ");
+  m.impl("transfer_kv_per_layer_direct_pf_lf", torch::kCUDA, &transfer_kv_per_layer_direct_pf_lf);
+  m.def(
+      "transfer_kv_all_layer_direct_lf_pf(Tensor[] src_ptrs, Tensor[] dst_ptrs, Tensor src_indices, "
+      "Tensor dst_indices, int page_size) ->() ");
+  m.impl("transfer_kv_all_layer_direct_lf_pf", torch::kCUDA, &transfer_kv_all_layer_direct_lf_pf);
   /*
    * From csrc/memory
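These m.def/m.impl pairs register the new kernels on the PyTorch dispatcher under the sgl_kernel namespace; the thin Python wrappers added in kvcacheio.py below simply forward to them. A quick way to confirm the ops are visible, assuming an installed sgl-kernel build:

import torch
import sgl_kernel  # loading the extension runs TORCH_LIBRARY_FRAGMENT and registers the ops

# Both new ops should now resolve on the dispatcher:
print(torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf)
print(torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf)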
sgl-kernel/csrc/kvcacheio/transfer.cu

@@ -437,8 +437,8 @@ void transfer_kv_all_layer_mla_lf_pf(
 }
 
 inline void transfer_page_direct(
-    const at::Tensor& src_buffer,
-    at::Tensor& dst_buffer,
+    const at::Tensor src_buffer,
+    at::Tensor dst_buffer,
     int64_t src_page_index,
     int64_t dst_page_index,
     int64_t page_size) {

@@ -493,3 +493,81 @@ void transfer_kv_direct(
     start_index = end_index;
   }
 }
+
+template <bool IsLf2Pf>
+inline void transfer_kv_page_first_direct_impl(
+    const std::vector<at::Tensor>& src_ptrs,
+    std::vector<at::Tensor> dst_ptrs,
+    const at::Tensor& src_indices,
+    const at::Tensor& dst_indices,
+    int64_t start_layer_id,
+    int64_t page_size) {
+  TORCH_CHECK(
+      src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length");
+  TORCH_CHECK(page_size > 0, "Page size must be positive");
+  TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size");
+  auto src_indices_cpu = src_indices.cpu();
+  auto dst_indices_cpu = dst_indices.cpu();
+  const int64_t num_pages = src_indices_cpu.size(0) / page_size;
+  if constexpr (IsLf2Pf) {
+    const bool is_mla = dst_ptrs.size() == 1;
+    const int64_t num_layers = is_mla ? src_ptrs.size() : src_ptrs.size() / 2;
+    for (const auto i : c10::irange(num_pages)) {
+      auto s_index = src_indices_cpu[i * page_size].item<int64_t>();
+      auto d_index = dst_indices_cpu[i * page_size].item<int64_t>() / page_size;
+      for (int64_t j = 0; j < num_layers; ++j) {
+        transfer_page_direct(
+            src_ptrs[j], dst_ptrs[0].select(0, d_index).select(0, start_layer_id + j), s_index, 0, page_size);
+        if (!is_mla) {
+          transfer_page_direct(
+              src_ptrs[j + num_layers],
+              dst_ptrs[1].select(0, d_index).select(0, start_layer_id + j),
+              s_index,
+              0,
+              page_size);
+        }
+      }
+    }
+  } else {
+    const bool is_mla = src_ptrs.size() == 1;
+    const int64_t num_layers = is_mla ? dst_ptrs.size() : dst_ptrs.size() / 2;
+    for (const auto i : c10::irange(num_pages)) {
+      auto s_index = src_indices_cpu[i * page_size].item<int64_t>() / page_size;
+      auto d_index = dst_indices_cpu[i * page_size].item<int64_t>();
+      for (int64_t j = 0; j < num_layers; ++j) {
+        transfer_page_direct(
+            src_ptrs[0].select(0, s_index).select(0, start_layer_id + j), dst_ptrs[j], 0, d_index, page_size);
+        if (!is_mla) {
+          transfer_page_direct(
+              src_ptrs[1].select(0, s_index).select(0, start_layer_id + j),
+              dst_ptrs[j + num_layers],
+              0,
+              d_index,
+              page_size);
+        }
+      }
+    }
+  }
+}
+
+void transfer_kv_per_layer_direct_pf_lf(
+    const std::vector<at::Tensor>& src_ptrs,
+    std::vector<at::Tensor> dst_ptrs,
+    const at::Tensor& src_indices,
+    const at::Tensor& dst_indices,
+    int64_t layer_id,
+    int64_t page_size) {
+  transfer_kv_page_first_direct_impl<false>(src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size);
+}
+
+void transfer_kv_all_layer_direct_lf_pf(
+    const std::vector<at::Tensor>& src_ptrs,
+    std::vector<at::Tensor> dst_ptrs,
+    const at::Tensor& src_indices,
+    const at::Tensor& dst_indices,
+    int64_t page_size) {
+  transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
+}
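Both index tensors hold element (token-slot) indices grouped page_size at a time, and the implementation reads only the first index of each group: on the layer-first side that index is used as-is (the page is assumed to occupy page_size contiguous slots starting there), while on the page-first side it is divided by page_size to get the page slot. A Python restatement of the IsLf2Pf branch for a single non-MLA K layer, using a hypothetical helper name for illustration only:

import torch

def lf_to_pf_one_layer(src_layer, dst_pf, src_indices, dst_indices, layer_slot, page_size):
    # src_layer: (total_items, item_size) layer-first device tensor for one layer
    # dst_pf:    (total_pages, num_layers, page_size, item_size) page-first host tensor
    num_pages = src_indices.numel() // page_size
    for i in range(num_pages):
        s_index = int(src_indices[i * page_size])               # element index, used directly
        d_index = int(dst_indices[i * page_size]) // page_size  # element index -> page slot
        dst_pf[d_index, layer_slot] = src_layer[s_index : s_index + page_size].to(dst_pf.device)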
sgl-kernel/include/sgl_kernel_ops.h

@@ -569,6 +569,21 @@ void transfer_kv_direct(
     const at::Tensor dst_indices,
     int64_t page_size);
 
+void transfer_kv_per_layer_direct_pf_lf(
+    const std::vector<at::Tensor>& src_ptrs,
+    std::vector<at::Tensor> dst_ptrs,
+    const at::Tensor& src_indices,
+    const at::Tensor& dst_indices,
+    int64_t layer_id,
+    int64_t page_size);
+
+void transfer_kv_all_layer_direct_lf_pf(
+    const std::vector<at::Tensor>& src_ptrs,
+    std::vector<at::Tensor> dst_ptrs,
+    const at::Tensor& src_indices,
+    const at::Tensor& dst_indices,
+    int64_t page_size);
+
 /*
  * From FlashInfer
  */
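Reading the suffixes: _pf_lf copies from a page-first source into layer-first destinations, and _lf_pf copies from layer-first sources into a page-first destination; they correspond to the IsLf2Pf=false and IsLf2Pf=true instantiations of transfer_kv_page_first_direct_impl in transfer.cu above.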
sgl-kernel/python/sgl_kernel/kvcacheio.py

@@ -128,6 +128,31 @@ def transfer_kv_direct(
     )
 
 
+def transfer_kv_per_layer_direct_pf_lf(
+    src_ptrs: List[torch.Tensor],
+    dst_ptrs: List[torch.Tensor],
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    layer_id: int,
+    page_size: int,
+):
+    torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf(
+        src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size
+    )
+
+
+def transfer_kv_all_layer_direct_lf_pf(
+    src_ptrs: List[torch.Tensor],
+    dst_ptrs: List[torch.Tensor],
+    src_indices: torch.Tensor,
+    dst_indices: torch.Tensor,
+    page_size: int,
+):
+    torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf(
+        src_ptrs, dst_ptrs, src_indices, dst_indices, page_size
+    )
+
+
 def transfer_kv_per_layer_mla(
     src: torch.Tensor,
     dst: torch.Tensor,
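A minimal end-to-end sketch of the all-layer wrapper, condensed from the is_mla=True, lf_to_pf=True case of the test below (a single destination pool; the non-MLA call passes concatenated K and V pointer lists and two destination pools instead). Sizes are illustrative; the index tensors must be page-aligned element indices whose length is a multiple of page_size:

import torch
from sgl_kernel.kvcacheio import transfer_kv_all_layer_direct_lf_pf

num_layers, total_items, page_size, item_size = 4, 1024, 16, 256
total_pages = total_items // page_size

src_pool = torch.randn(num_layers, total_items, item_size, device="cuda")
src_ptrs = [src_pool[i] for i in range(num_layers)]  # one layer-first tensor per layer
dst_pool = torch.zeros(total_pages, num_layers, page_size, item_size).pin_memory()

# Page-aligned element indices: copy device pages 0..3 into host pages 4..7.
src_indices = torch.arange(0, 4 * page_size)
dst_indices = torch.arange(4 * page_size, 8 * page_size)

transfer_kv_all_layer_direct_lf_pf(src_ptrs, [dst_pool], src_indices, dst_indices, page_size)
torch.cuda.synchronize()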
sgl-kernel/tests/test_kvcacheio.py

@@ -2,9 +2,11 @@ import pytest
 import torch
 from sgl_kernel.kvcacheio import (
     transfer_kv_all_layer,
+    transfer_kv_all_layer_direct_lf_pf,
     transfer_kv_all_layer_mla,
     transfer_kv_direct,
     transfer_kv_per_layer,
+    transfer_kv_per_layer_direct_pf_lf,
     transfer_kv_per_layer_mla,
 )

@@ -13,6 +15,21 @@ def ref_copy_with_indices(src_pool, dst_pool, src_indices, dst_indices):
     dst_pool[dst_indices] = src_pool[src_indices].to(dst_pool.device)
 
 
+def ref_copy_with_indices_pf_direct(
+    src_pool, dst_pool, src_indices, dst_indices, page_size, layer_id, lf_to_pf=False
+):
+    if lf_to_pf:
+        for i in range(0, len(src_indices), page_size):
+            dst_pool[dst_indices[i] // page_size][layer_id] = src_pool[layer_id][
+                src_indices[i : i + page_size]
+            ].to(dst_pool.device)
+    else:
+        for i in range(0, len(src_indices), page_size):
+            dst_pool[layer_id][dst_indices[i : i + page_size]] = src_pool[
+                src_indices[i] // page_size
+            ][layer_id].to(dst_pool.device)
+
+
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("num_items_to_transfer", [1, 128, 1024])
 @pytest.mark.parametrize("page_size", [1, 16, 64])

@@ -251,5 +268,218 @@ def test_transfer_kv(
     torch.set_default_dtype(original_dtype)
 
 
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("num_items_to_transfer", [128, 1024, 8192])
+@pytest.mark.parametrize("page_size", [16, 64, 128])
+@pytest.mark.parametrize("item_size", [256])
+@pytest.mark.parametrize("total_items_in_pool", [20480])
+@pytest.mark.parametrize("is_mla", [False, True])
+@pytest.mark.parametrize("lf_to_pf", [False, True])
+def test_transfer_kv_pf_direct(
+    dtype: torch.dtype,
+    num_items_to_transfer: int,
+    item_size: int,
+    page_size: int,
+    total_items_in_pool: int,
+    is_mla: bool,
+    lf_to_pf: bool,
+):
+    original_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    device = "cuda"
+    torch.cuda.manual_seed(42)
+
+    num_layers = 4
+    total_pages_in_pool = total_items_in_pool // page_size
+    num_pages_to_transfer = num_items_to_transfer // page_size
+    if num_pages_to_transfer == 0:
+        torch.set_default_dtype(original_dtype)
+        return
+
+    page_indices = torch.randperm(total_pages_in_pool, dtype=torch.int64)
+    src_indices_host = torch.cat(
+        [
+            torch.arange(p * page_size, (p + 1) * page_size)
+            for p in page_indices[:num_pages_to_transfer]
+        ]
+    )
+    src_indices_device = src_indices_host.to(device)
+    dst_indices_host = torch.cat(
+        [
+            torch.arange(p * page_size, (p + 1) * page_size)
+            for p in page_indices[num_pages_to_transfer : 2 * num_pages_to_transfer]
+        ]
+    )
+    dst_indices_device = dst_indices_host.to(device)
+
+    # We will test the per-layer function on the first layer (index 0) of the pool.
+    layer_idx_to_test = 0
+
+    if lf_to_pf:
+        if is_mla:
+            src_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(device)
+            src_pool_ptrs = [src_pool[i] for i in range(num_layers)]
+            dst_pool_ref = torch.zeros(
+                total_pages_in_pool, num_layers, page_size, item_size
+            ).pin_memory()
+            dst_pool_direct = torch.zeros_like(dst_pool_ref)
+            torch.cuda.synchronize()
+            transfer_kv_all_layer_direct_lf_pf(
+                src_pool_ptrs,
+                [dst_pool_direct],
+                src_indices_host,
+                dst_indices_host,
+                page_size,
+            )
+            for i in range(num_layers):
+                ref_copy_with_indices_pf_direct(
+                    src_pool,
+                    dst_pool_ref,
+                    src_indices_device,
+                    dst_indices_host,
+                    page_size,
+                    i,
+                    lf_to_pf=True,
+                )
+            torch.cuda.synchronize()
+            torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
+        else:
+            src_k_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(device)
+            src_k_pool_ptrs = [src_k_pool[i] for i in range(num_layers)]
+            src_v_pool = torch.randn(num_layers, total_items_in_pool, item_size).to(device)
+            src_v_pool_ptrs = [src_v_pool[i] for i in range(num_layers)]
+            dst_k_pool_ref = torch.zeros(
+                total_pages_in_pool, num_layers, page_size, item_size
+            ).pin_memory()
+            dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
+            dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
+            dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
+            torch.cuda.synchronize()
+            transfer_kv_all_layer_direct_lf_pf(
+                src_k_pool_ptrs + src_v_pool_ptrs,
+                [dst_k_pool_direct, dst_v_pool_direct],
+                src_indices_host,
+                dst_indices_host,
+                page_size,
+            )
+            for i in range(num_layers):
+                ref_copy_with_indices_pf_direct(
+                    src_k_pool,
+                    dst_k_pool_ref,
+                    src_indices_device,
+                    dst_indices_host,
+                    page_size,
+                    i,
+                    lf_to_pf=True,
+                )
+                ref_copy_with_indices_pf_direct(
+                    src_v_pool,
+                    dst_v_pool_ref,
+                    src_indices_device,
+                    dst_indices_host,
+                    page_size,
+                    i,
+                    lf_to_pf=True,
+                )
+            torch.cuda.synchronize()
+            torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
+            torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
+    else:
+        if is_mla:
+            src_pool = torch.randn(
+                total_pages_in_pool, num_layers, page_size, item_size
+            ).pin_memory()
+            dst_pool_ref = torch.zeros(num_layers, total_items_in_pool, item_size).to(device)
+            dst_pool_direct = torch.zeros_like(dst_pool_ref)
+            dst_pool_direct_ptrs = [dst_pool_direct[i] for i in range(num_layers)]
+            torch.cuda.synchronize()
+            transfer_kv_per_layer_direct_pf_lf(
+                [src_pool],
+                [dst_pool_direct_ptrs[layer_idx_to_test]],
+                src_indices_host,
+                dst_indices_host,
+                layer_idx_to_test,
+                page_size,
+            )
+            ref_copy_with_indices_pf_direct(
+                src_pool,
+                dst_pool_ref,
+                src_indices_host,
+                dst_indices_device,
+                page_size,
+                layer_idx_to_test,
+                lf_to_pf=False,
+            )
+            torch.cuda.synchronize()
+            torch.testing.assert_close(dst_pool_direct, dst_pool_ref)
+        else:
+            src_k_pool = torch.randn(
+                total_pages_in_pool, num_layers, page_size, item_size
+            ).pin_memory()
+            src_v_pool = torch.randn(
+                total_pages_in_pool, num_layers, page_size, item_size
+            ).pin_memory()
+            dst_k_pool_ref = torch.zeros(num_layers, total_items_in_pool, item_size).to(device)
+            dst_k_pool_direct = torch.zeros_like(dst_k_pool_ref)
+            dst_k_pool_direct_ptrs = [dst_k_pool_direct[i] for i in range(num_layers)]
+            dst_v_pool_ref = torch.zeros_like(dst_k_pool_ref)
+            dst_v_pool_direct = torch.zeros_like(dst_v_pool_ref)
+            dst_v_pool_direct_ptrs = [dst_v_pool_direct[i] for i in range(num_layers)]
+            torch.cuda.synchronize()
+            transfer_kv_per_layer_direct_pf_lf(
+                [src_k_pool, src_v_pool],
+                [
+                    dst_k_pool_direct_ptrs[layer_idx_to_test],
+                    dst_v_pool_direct_ptrs[layer_idx_to_test],
+                ],
+                src_indices_host,
+                dst_indices_host,
+                layer_idx_to_test,
+                page_size,
+            )
+            ref_copy_with_indices_pf_direct(
+                src_k_pool,
+                dst_k_pool_ref,
+                src_indices_host,
+                dst_indices_device,
+                page_size,
+                layer_idx_to_test,
+                lf_to_pf=False,
+            )
+            ref_copy_with_indices_pf_direct(
+                src_v_pool,
+                dst_v_pool_ref,
+                src_indices_host,
+                dst_indices_device,
+                page_size,
+                layer_idx_to_test,
+                lf_to_pf=False,
+            )
+            torch.cuda.synchronize()
+            torch.testing.assert_close(dst_k_pool_direct, dst_k_pool_ref)
+            torch.testing.assert_close(dst_v_pool_direct, dst_v_pool_ref)
+    torch.set_default_dtype(original_dtype)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])