sglang · Commit d134c139 (unverified)

Optimize the update flashinfer indices (#1262)

Authored by xiaobochen on Aug 31, 2024; committed by GitHub on Aug 31, 2024.
Parent: 6cc9c525
Showing 2 changed files with 137 additions and 23 deletions:

- python/sglang/srt/model_executor/forward_batch_info.py (+61, −23)
- test/srt/test_create_kvindices.py (+76, −0)
python/sglang/srt/model_executor/forward_batch_info.py
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List

 import numpy as np
 import torch
+import triton
+import triton.language as tl

 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -262,6 +264,42 @@ class InputMetadata:
         )

+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    max_context_len,
+    kv_indices_ptr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
+
+
 def update_flashinfer_indices(
     forward_mode,
     model_runner,
@@ -285,17 +323,18 @@ def update_flashinfer_indices(
         kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                model_runner.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(batch_size)
-            ],
-            dim=0,
-        ).contiguous()
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            None,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )

         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
         if forward_mode == ForwardMode.DECODE:
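
This hunk is the optimization itself: the old path synced `req_pool_indices` and `paged_kernel_lens` to the host via `.cpu().numpy()` and assembled `kv_indices` with a per-request Python loop plus `torch.cat`, while the new path preallocates the output from `kv_indptr[-1]` and fills it with a single kernel launch over a `(batch_size,)` grid, keeping the work on the GPU. A small sketch of how `kv_indptr` partitions the flat buffer, using made-up lengths:

```python
import torch

# Three requests with KV lengths 3, 1, and 4 (illustrative numbers).
paged_kernel_lens = torch.tensor([3, 1, 4], dtype=torch.int32)
kv_indptr = torch.zeros(paged_kernel_lens.numel() + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
print(kv_indptr.tolist())  # [0, 3, 4, 8]
# Request i fills kv_indices[kv_indptr[i] : kv_indptr[i + 1]], so
# torch.empty(kv_indptr[-1]) is exactly the right output size.
```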
@@ -365,18 +404,17 @@ def update_flashinfer_indices(
         kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                model_runner.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i],
-                    kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
-                ]
-                for i in range(batch_size)
-            ],
-            dim=0,
-        ).contiguous()
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )

         if forward_mode == ForwardMode.DECODE:
             # CUDA graph uses different flashinfer_decode_wrapper
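
The second rewritten call site is identical except that it passes `kv_start_idx` instead of `None`; inside the kernel this shifts each request's load window by `kv_start_idx[pid]`, which appears to serve the path where a prefix of the sequence is already cached and should be skipped (the surrounding context is elided in this diff). A toy illustration of the offset semantics, not the runner's real call site:

```python
import torch

# One request whose row holds slots 0..9; the first two slots belong to
# an already-cached prefix, so copying should start at offset 2.
req_to_token = torch.arange(10, dtype=torch.int32).reshape(1, 10)
kv_start_idx = torch.tensor([2], dtype=torch.int32)
paged_kernel_lens = torch.tensor([3], dtype=torch.int32)

start, length = int(kv_start_idx[0]), int(paged_kernel_lens[0])
kv_indices = req_to_token[0, start : start + length]
print(kv_indices.tolist())  # [2, 3, 4]
```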
test/srt/test_create_kvindices.py (new file, mode 100644)
import itertools
import unittest

import numpy as np
import torch

from sglang.srt.model_executor.forward_batch_info import (
    create_flashinfer_kv_indices_triton,
)


class TestCreateKvIndices(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_test(self, batch, max_batch, max_context_len):
        req_to_token = torch.arange(
            max_batch * max_context_len, dtype=torch.int32, device="cuda"
        ).reshape((max_batch, max_context_len))
        req_pool_indices = torch.tensor(
            torch.from_numpy(
                np.random.choice(range(max_batch), size=batch, replace=False)
            ),
            dtype=torch.int32,
            device="cuda",
        )
        paged_kernel_lens = torch.tensor(
            torch.from_numpy(
                np.random.choice(range(max_context_len), size=batch, replace=False)
            ),
            dtype=torch.int32,
            device="cuda",
        )

        kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

        # ref
        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
        kv_indices_ref = torch.cat(
            [
                req_to_token[req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]]
                for i in range(batch)
            ],
            dim=0,
        ).contiguous()

        # triton
        kv_indices_triton = torch.empty(
            kv_indptr[-1], dtype=torch.int32, device="cuda"
        )
        create_flashinfer_kv_indices_triton[(batch,)](
            req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            kv_indptr,
            None,
            req_to_token.size(1),
            kv_indices_triton,
        )

        # Check
        self.assertTrue(torch.equal(kv_indices_ref, kv_indices_triton))

    def test_create_kvindices(self):
        BATCH = [1, 37, 1786]
        MAX_BATCH = 4096
        MAX_CONTEXT_LEN = 4096
        for batch in BATCH:
            self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN)


if __name__ == "__main__":
    unittest.main()
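
The test validates the kernel against the same `torch.cat` construction the commit removes, for batch sizes 1, 37, and 1786, and can be run directly on a CUDA machine with `python3 test/srt/test_create_kvindices.py`. For a rough sense of why the kernel path should win, here is a timing sketch along the same lines; it is not part of the commit, the shapes are illustrative, and the first launch includes Triton JIT compilation, so warm up before trusting numbers:

```python
import time

import torch

from sglang.srt.model_executor.forward_batch_info import (
    create_flashinfer_kv_indices_triton,
)

batch, max_batch, max_context_len = 1024, 4096, 4096
req_to_token = torch.arange(
    max_batch * max_context_len, dtype=torch.int32, device="cuda"
).reshape(max_batch, max_context_len)
req_pool_indices = torch.randint(
    0, max_batch, (batch,), dtype=torch.int32, device="cuda"
)
paged_kernel_lens = torch.randint(
    1, max_context_len, (batch,), dtype=torch.int32, device="cuda"
)
kv_indptr = torch.zeros(batch + 1, dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# Old path: device-to-host sync, per-request slicing, torch.cat.
torch.cuda.synchronize()
t0 = time.perf_counter()
idx_cpu = req_pool_indices.cpu().numpy()
lens_cpu = paged_kernel_lens.cpu().numpy()
ref = torch.cat(
    [req_to_token[idx_cpu[i], : lens_cpu[i]] for i in range(batch)], dim=0
)
torch.cuda.synchronize()
print("torch.cat path:", time.perf_counter() - t0)

# New path: one preallocated buffer, one kernel launch.
torch.cuda.synchronize()
t0 = time.perf_counter()
out = torch.empty(int(kv_indptr[-1]), dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(batch,)](
    req_to_token,
    req_pool_indices,
    paged_kernel_lens,
    kv_indptr,
    None,
    req_to_token.size(1),
    out,
)
torch.cuda.synchronize()
print("triton path:", time.perf_counter() - t0)
assert torch.equal(ref, out)
```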