sglang / Commits / 144bc70f

Commit 144bc70f (unverified), authored Sep 10, 2024 by Liangsheng Yin, committed via GitHub on Sep 10, 2024

Organize flashinfer indices update (#1378)

parent 46094e0c
Showing 5 changed files with 252 additions and 200 deletions (+252, -200).
python/sglang/srt/layers/flashinfer_utils.py             +237, -0
python/sglang/srt/managers/schedule_batch.py             +4, -0
python/sglang/srt/model_executor/cuda_graph_runner.py    +2, -5
python/sglang/srt/model_executor/forward_batch_info.py   +8, -192
test/srt/test_create_kvindices.py                        +1, -3
python/sglang/srt/layers/flashinfer_utils.py (new file, 0 → 100644)

import torch
import triton
import triton.language as tl


@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    max_context_len,
    kv_indices_ptr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)
    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    kv_indices_offset = tl.load(kv_indptr + pid)

    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    req_to_token_ptr += req_pool_index * max_context_len
    kv_indices_ptr += kv_indices_offset

    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
    st_offset = tl.arange(0, BLOCK_SIZE)
    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = ld_offset < kv_end
        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
        ld_offset += BLOCK_SIZE
        st_offset += BLOCK_SIZE


class FlashinferUpdater:
    def __init__(
        self,
        forward_mode,
        model_runner,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        flashinfer_decode_wrapper=None,
        flashinfer_use_ragged=False,
    ):
        self.forward_mode = forward_mode
        self.model_runner = model_runner
        self.req_pool_indices = req_pool_indices
        self.seq_lens = seq_lens
        self.prefix_lens = prefix_lens
        self.flashinfer_use_ragged = flashinfer_use_ragged

        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // model_runner.tp_size
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            model_runner.tp_size
        )
        self.head_dim = model_runner.model_config.head_dim
        self.batch_size = len(req_pool_indices)

        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )

        (
            self.flashinfer_decode_wrapper,
            self.flashinfer_prefill_wrapper_ragged,
            self.flashinfer_prefill_wrapper_paged,
        ) = (
            flashinfer_decode_wrapper,
            self.model_runner.flashinfer_prefill_wrapper_ragged,
            self.model_runner.flashinfer_prefill_wrapper_paged,
        )

        # CUDA graph uses different flashinfer_decode_wrapper
        if self.flashinfer_decode_wrapper is None:
            self.flashinfer_decode_wrapper = self.model_runner.flashinfer_decode_wrapper

    def _init_indices_no_window(self):
        if self.flashinfer_use_ragged:
            paged_kernel_lens = self.prefix_lens
        else:
            paged_kernel_lens = self.seq_lens

        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        self.kv_indices = torch.empty(
            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
        )

        create_flashinfer_kv_indices_triton[(self.batch_size,)](
            self.model_runner.req_to_token_pool.req_to_token,
            self.req_pool_indices,
            paged_kernel_lens,
            self.kv_indptr,
            None,
            self.model_runner.req_to_token_pool.req_to_token.size(1),
            self.kv_indices,
        )

    def _init_indices_window(self, wrapper_id):
        # window attention use paged only
        if wrapper_id == 0:
            if self.forward_mode.is_decode():
                paged_kernel_lens = torch.minimum(
                    self.seq_lens,
                    torch.tensor(self.model_runner.sliding_window_size + 1),
                )
            else:
                paged_kernel_lens = torch.minimum(
                    self.seq_lens,
                    torch.tensor(self.model_runner.sliding_window_size)
                    + self.seq_lens
                    - self.prefix_lens,
                )
        else:
            paged_kernel_lens = self.seq_lens

        kv_start_idx = self.seq_lens - paged_kernel_lens

        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        self.kv_indices = torch.empty(
            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
        )
        create_flashinfer_kv_indices_triton[(self.batch_size,)](
            self.model_runner.req_to_token_pool.req_to_token,
            self.req_pool_indices,
            paged_kernel_lens,
            self.kv_indptr,
            kv_start_idx,
            self.model_runner.req_to_token_pool.req_to_token.size(1),
            self.kv_indices,
        )

    def _update_decode_indices(self, decode_wrapper):
        decode_wrapper.end_forward()
        decode_wrapper.begin_forward(
            self.kv_indptr,
            self.kv_indices,
            self.kv_last_page_len,
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
            data_type=self.model_runner.kv_cache_dtype,
            q_data_type=self.model_runner.dtype,
        )

    def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
        # extend part
        qo_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)

        if self.flashinfer_use_ragged:
            ragged_wrapper.end_forward()
            ragged_wrapper.begin_forward(
                qo_indptr,
                qo_indptr,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
            )

        # cached part
        paged_wrapper.end_forward()
        paged_wrapper.begin_forward(
            qo_indptr,
            self.kv_indptr,
            self.kv_indices,
            self.kv_last_page_len,
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
        )

    def update_indices_no_window(self):
        self._init_indices_no_window()
        if self.forward_mode.is_decode():
            self._update_decode_indices(self.flashinfer_decode_wrapper)
        else:
            self._update_extend_indices(
                self.flashinfer_prefill_wrapper_ragged,
                self.flashinfer_prefill_wrapper_paged,
            )

    def update_indices_window(self):
        assert self.flashinfer_use_ragged is False

        for wrapper_id in range(2):
            self._init_indices_window(wrapper_id)
            if self.forward_mode.is_decode():
                self._update_decode_indices(self.flashinfer_decode_wrapper[wrapper_id])
            else:
                self._update_extend_indices(
                    None,
                    self.flashinfer_prefill_wrapper_paged[wrapper_id],
                )


def update_flashinfer_indices(
    forward_mode,
    model_runner,
    req_pool_indices,
    seq_lens,
    prefix_lens,
    flashinfer_decode_wrapper=None,
    flashinfer_use_ragged=False,
):
    flashinfer_updater = FlashinferUpdater(
        forward_mode,
        model_runner,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        flashinfer_decode_wrapper,
        flashinfer_use_ragged,
    )

    if model_runner.sliding_window_size is None:
        flashinfer_updater.update_indices_no_window()
    else:
        flashinfer_updater.update_indices_window()
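
For intuition about the new kernel and the index layout it fills: kv_indptr is an exclusive prefix sum of the per-request KV lengths, and program pid of create_flashinfer_kv_indices_triton copies req_to_token[req_pool_indices[pid], kv_start:kv_end] into kv_indices[kv_indptr[pid]:kv_indptr[pid + 1]]. The standalone snippet below mirrors the way _init_indices_no_window launches the kernel; the shapes and values are made up for illustration and it assumes a CUDA device with Triton available — it is not part of this commit.

import torch

from sglang.srt.layers.flashinfer_utils import create_flashinfer_kv_indices_triton

batch_size, max_context_len = 2, 8
# Hypothetical req_to_token map: row r holds the KV-cache slot of each token of request r.
req_to_token = torch.arange(
    batch_size * max_context_len, dtype=torch.int32, device="cuda"
).reshape(batch_size, max_context_len)
req_pool_indices = torch.tensor([1, 0], dtype=torch.int32, device="cuda")
paged_kernel_lens = torch.tensor([5, 3], dtype=torch.int32, device="cuda")

# Exclusive prefix sum: [0, 5, 8]; request i writes kv_indices[kv_indptr[i]:kv_indptr[i+1]].
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")

create_flashinfer_kv_indices_triton[(batch_size,)](
    req_to_token,
    req_pool_indices,
    paged_kernel_lens,
    kv_indptr,
    None,
    max_context_len,
    kv_indices,
)
# kv_indices now holds the first 5 entries of row 1 (8..12) followed by
# the first 3 entries of row 0 (0..2).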

python/sglang/srt/managers/schedule_batch.py
...
@@ -349,6 +349,7 @@ class ScheduleBatch:
     # For mixed chunekd prefill
     prefix_lens_cpu: List[int] = None
+    running_bs: int = None

     # For processing logprobs
     return_logprob: bool = False
...
@@ -446,6 +447,9 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

     def mix_with_running(self, running_batch: "ScheduleBatch"):
+        self.forward_mode = ForwardMode.MIXED
+        self.running_bs = running_batch.batch_size()
+
         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
         prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
         prefix_lens_cpu.extend(
...

python/sglang/srt/model_executor/cuda_graph_runner.py
...
@@ -25,6 +25,7 @@ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
...
@@ -32,11 +33,7 @@ from sglang.srt.layers.logits_processor import (
 )
 from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardMode,
-    InputMetadata,
-    update_flashinfer_indices,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather
...

python/sglang/srt/model_executor/forward_batch_info.py
...
@@ -22,8 +22,8 @@ from typing import TYPE_CHECKING, List
 import numpy as np
 import torch
-import triton
-import triton.language as tl
+
+from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
...
@@ -39,16 +39,21 @@ class ForwardMode(IntEnum):
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
+    # Contains both PREFILL and EXTEND.
+    MIXED = auto()

     def is_prefill(self):
         return self == ForwardMode.PREFILL

     def is_extend(self):
-        return self == ForwardMode.EXTEND
+        return self == ForwardMode.EXTEND or self == ForwardMode.MIXED

     def is_decode(self):
         return self == ForwardMode.DECODE

+    def is_mixed(self):
+        return self == ForwardMode.MIXED


 @dataclass
 class InputMetadata:
...
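
The new MIXED mode deliberately routes through the extend path; a quick illustration of the dispatch behavior added above (not part of the commit):

from sglang.srt.model_executor.forward_batch_info import ForwardMode

assert ForwardMode.MIXED.is_mixed()
assert ForwardMode.MIXED.is_extend()      # mixed batches reuse the extend code path
assert not ForwardMode.MIXED.is_decode()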
@@ -270,192 +275,3 @@ class InputMetadata:
             model_runner.flashinfer_decode_wrapper,
             flashinfer_use_ragged,
         )
-
-
-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    max_context_len,
-    kv_indices_ptr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    req_to_token_ptr += req_pool_index * max_context_len
-    kv_indices_ptr += kv_indices_offset
-
-    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
-    st_offset = tl.arange(0, BLOCK_SIZE)
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for _ in range(num_loop):
-        mask = ld_offset < kv_end
-        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
-        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
-        ld_offset += BLOCK_SIZE
-        st_offset += BLOCK_SIZE
-
-
-def update_flashinfer_indices(
-    forward_mode,
-    model_runner,
-    req_pool_indices,
-    seq_lens,
-    prefix_lens,
-    flashinfer_decode_wrapper=None,
-    flashinfer_use_ragged=False,
-):
-    """Init auxiliary variables for FlashInfer attention backend."""
-    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
-    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
-    head_dim = model_runner.model_config.head_dim
-    batch_size = len(req_pool_indices)
-
-    if model_runner.sliding_window_size is None:
-        if flashinfer_use_ragged:
-            paged_kernel_lens = prefix_lens
-        else:
-            paged_kernel_lens = seq_lens
-
-        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
-        create_flashinfer_kv_indices_triton[(batch_size,)](
-            model_runner.req_to_token_pool.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            None,
-            model_runner.req_to_token_pool.req_to_token.size(1),
-            kv_indices,
-        )
-
-        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-        if forward_mode.is_decode():
-            # CUDA graph uses different flashinfer_decode_wrapper
-            if flashinfer_decode_wrapper is None:
-                flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
-
-            flashinfer_decode_wrapper.end_forward()
-            flashinfer_decode_wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                data_type=model_runner.kv_cache_dtype,
-                q_data_type=model_runner.dtype,
-            )
-        else:
-            # extend part
-            qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-            qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-            if flashinfer_use_ragged:
-                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                    qo_indptr,
-                    qo_indptr,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                )
-
-            # cached part
-            model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-            )
-    else:
-        # window attention use paged only
-        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-        for wrapper_id in range(2):
-            if wrapper_id == 0:
-                if forward_mode.is_decode():
-                    paged_kernel_lens = torch.minimum(
-                        seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
-                    )
-                else:
-                    paged_kernel_lens = torch.minimum(
-                        seq_lens,
-                        torch.tensor(model_runner.sliding_window_size)
-                        + seq_lens
-                        - prefix_lens,
-                    )
-            else:
-                paged_kernel_lens = seq_lens
-
-            kv_start_idx = seq_lens - paged_kernel_lens
-
-            kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-            kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-            kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
-
-            create_flashinfer_kv_indices_triton[(batch_size,)](
-                model_runner.req_to_token_pool.req_to_token,
-                req_pool_indices,
-                paged_kernel_lens,
-                kv_indptr,
-                kv_start_idx,
-                model_runner.req_to_token_pool.req_to_token.size(1),
-                kv_indices,
-            )
-
-            if forward_mode.is_decode():
-                # CUDA graph uses different flashinfer_decode_wrapper
-                if flashinfer_decode_wrapper is None:
-                    flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
-
-                flashinfer_decode_wrapper[wrapper_id].end_forward()
-                flashinfer_decode_wrapper[wrapper_id].begin_forward(
-                    kv_indptr,
-                    kv_indices,
-                    kv_last_page_len,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                    1,
-                    data_type=model_runner.kv_cache_dtype,
-                    q_data_type=model_runner.dtype,
-                )
-            else:
-                # extend part
-                qo_indptr = torch.zeros(
-                    (batch_size + 1,), dtype=torch.int32, device="cuda"
-                )
-                qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
-                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
-                    qo_indptr,
-                    kv_indptr,
-                    kv_indices,
-                    kv_last_page_len,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                    1,
-                )
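
Both the removed monolithic function above and the new FlashinferUpdater compute the sliding-window lengths the same way. With hypothetical numbers (CPU tensors for brevity; illustrative only, not part of the commit), wrapper 0 in decode mode clamps each request to sliding_window_size + 1 tokens and offsets kv_start_idx past the evicted prefix, while wrapper 1 keeps the full sequence:

import torch

sliding_window_size = 4                    # hypothetical window
seq_lens = torch.tensor([10, 3])           # hypothetical decode-mode batch

# wrapper_id == 0: window-limited lengths and the offset where the window starts
paged_kernel_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
kv_start_idx = seq_lens - paged_kernel_lens
# paged_kernel_lens == tensor([5, 3]); kv_start_idx == tensor([5, 0])

# wrapper_id == 1: full-attention layers index the entire sequence
paged_kernel_lens_full = seq_lens          # kv_start_idx would be tensor([0, 0])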

test/srt/test_create_kvindices.py
...
@@ -4,9 +4,7 @@ import unittest
 import numpy as np
 import torch

-from sglang.srt.model_executor.forward_batch_info import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.flashinfer_utils import create_flashinfer_kv_indices_triton


 class TestCreateKvIndices(unittest.TestCase):
...
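
A CPU reference of the kernel's gather, of the kind a test like this can compare against (the helper below is illustrative only and is not the test's actual code):

import torch

def create_flashinfer_kv_indices_reference(
    req_to_token, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx=None
):
    # Plain PyTorch equivalent of create_flashinfer_kv_indices_triton, for comparison.
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for pid in range(len(req_pool_indices)):
        kv_start = int(kv_start_idx[pid]) if kv_start_idx is not None else 0
        kv_end = kv_start + int(paged_kernel_lens[pid])
        row = req_to_token[int(req_pool_indices[pid])]
        out = int(kv_indptr[pid])
        kv_indices[out : out + (kv_end - kv_start)] = row[kv_start:kv_end]
    return kv_indices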