change / sglang · Commits · af4e7910

Commit af4e7910 (unverified)
Authored Jul 12, 2024 by Lianmin Zheng; committed by GitHub on Jul 12, 2024
Parent: 519e20cf

Clean up the usage of flashinfer (#610)

Showing 5 changed files with 46 additions and 75 deletions (+46, -75):
python/sglang/srt/layers/radix_attention.py            +3  -12
python/sglang/srt/layers/token_attention.py            +1  -7
python/sglang/srt/managers/controller/infer_batch.py   +36 -48
python/sglang/srt/managers/controller/model_runner.py  +0  -8
python/sglang/srt/server_args.py                        +6  -0
python/sglang/srt/layers/radix_attention.py  (view file @ af4e7910)
@@ -31,21 +31,13 @@ class RadixAttention(nn.Module):
         self.layer_id = layer_id
 
         if not global_server_args_dict.get("disable_flashinfer", False):
-            self.prefill_forward = self.prefill_forward_flashinfer
-            self.extend_forward = self.prefill_forward_flashinfer
+            self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
-            # flashinfer now accepts float logit_cap argument
-            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         else:
-            self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-            self.logit_cap = logit_cap if logit_cap is not None else 0
 
-    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
-        # See the extend_forward_xxx functions.
-        raise NotImplementedError()
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
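After this hunk the layer exposes only two entry points, extend_forward and decode_forward, bound once at construction time depending on whether flashinfer is enabled. A minimal sketch of how such a dispatch is typically consumed; the actual forward body is outside this diff:

# Hedged sketch, not part of this commit. ForwardMode and InputMetadata are the
# classes referenced elsewhere in the diff.
def forward(self, q, k, v, input_metadata):
    if input_metadata.forward_mode == ForwardMode.DECODE:
        return self.decode_forward(q, k, v, input_metadata)
    # Plain prefill and prefill-with-cached-prefix are both handled as "extend" in SGLang,
    # which is why the separate prefill_forward binding could be dropped above.
    return self.extend_forward(q, k, v, input_metadata)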
@@ -86,7 +78,6 @@ class RadixAttention(nn.Module):
             input_metadata.start_loc,
             input_metadata.seq_lens,
             input_metadata.max_seq_len,
-            input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
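The normalized logit_cap above follows the usual convention: a value of 0 disables capping, while a positive value soft-clamps attention scores with tanh (as in Grok/Gemma-2 style models). An illustrative sketch of that clamping, assuming this convention and not taken from this diff:

import torch

def cap_logits(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    # logit_cap <= 0 means "no capping"; otherwise squash scores into (-logit_cap, logit_cap).
    if logit_cap > 0:
        return logit_cap * torch.tanh(scores / logit_cap)
    return scores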
@@ -94,7 +85,7 @@ class RadixAttention(nn.Module):
 
         return o
 
-    def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
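forward_return_lse returns both the partial attention output and its log-sum-exp, so a result computed over the new (ragged) tokens can later be combined with a result computed over a cached prefix. A self-contained sketch of that merge in plain PyTorch; names and shapes are illustrative, and the real code relies on a flashinfer helper rather than this function:

import torch

def merge_attention_states(o1, s1, o2, s2):
    # o1, o2: [num_tokens, num_heads, head_dim] partial outputs over disjoint key sets
    # s1, s2: [num_tokens, num_heads] log-sum-exp of the corresponding attention scores
    lse = torch.logaddexp(s1, s2)            # combined normalizer
    w1 = torch.exp(s1 - lse).unsqueeze(-1)   # weight of the first partial result
    w2 = torch.exp(s2 - lse).unsqueeze(-1)   # weight of the second partial result
    return o1 * w1 + o2 * w2, lse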
python/sglang/srt/layers/token_attention.py  (view file @ af4e7910)
@@ -107,7 +107,6 @@ def _fwd_kernel_stage2(
     stride_obs,
     stride_oh,
     stride_req_to_token_b,
-    other_kv_index,  # To fix a NAN issue
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -138,7 +137,7 @@ def _fwd_kernel_stage2(
         + cur_batch_req_idx * stride_req_to_token_b
         + (start_n + offs_n),
         mask=(start_n + offs_n) < cur_batch_seq_len,
-        other=other_kv_index,
+        other=0,
     )
     qk = tl.load(
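In Triton, the other= argument of tl.load is only the fill value returned for masked-out lanes; this change replaces the per-batch other_kv_index fallback with a constant 0. A minimal, self-contained sketch of masked loading with a hypothetical kernel, not taken from this repository:

import torch
import triton
import triton.language as tl

@triton.jit
def masked_copy_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    # Lanes where mask is False do not touch memory; they receive `other` (here 0).
    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    tl.store(y_ptr + offsets, x, mask=mask)

x = torch.arange(10, device="cuda", dtype=torch.float32)
y = torch.empty_like(x)
masked_copy_kernel[(triton.cdiv(x.numel(), 8),)](x, y, x.numel(), BLOCK=8)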
@@ -250,7 +249,6 @@ def _token_softmax_reducev_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
-    other_kv_index,
 ):
     BLOCK = 64
     batch, head = b_seq_len.shape[0], logics.shape[0]
@@ -277,7 +275,6 @@ def _token_softmax_reducev_fwd(
             o.stride(0),
             o.stride(1),
             req_to_tokens.stride(0),
-            other_kv_index,
         )
         return
@@ -295,7 +292,6 @@ def _token_softmax_reducev_fwd(
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=v_buffer.shape[-1],
         BLOCK_N=BLOCK,
@@ -315,7 +311,6 @@ def token_attention_fwd(
     b_start_loc,
     b_seq_len,
     max_len_in_batch,
-    other_kv_index,
     total_num_tokens,
     sm_scale=None,
     logit_cap=-1,
@@ -347,5 +342,4 @@ def token_attention_fwd(
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        other_kv_index,
     )
python/sglang/srt/managers/controller/infer_batch.py  (view file @ af4e7910)
@@ -729,7 +729,6 @@ class InputMetadata:
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None
-    other_kv_index: torch.Tensor = None
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
@@ -743,24 +742,19 @@ class InputMetadata:
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
     def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if self.forward_mode == ForwardMode.EXTEND:
+        if self.forward_mode == ForwardMode.DECODE:
+            paged_kernel_lens = self.seq_lens
+        else:
             paged_kernel_lens = self.prefix_lens
             self.no_prefix = torch.all(self.prefix_lens == 0)
-        else:
-            paged_kernel_lens = self.seq_lens
 
-        self.kv_indptr = torch.zeros(
+        kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        self.kv_indices = torch.cat(
+        kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
                     req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
                 ]
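For intuition, kv_indptr is a CSR-style offset array over the flattened KV indices: the slice kv_indptr[i] to kv_indptr[i + 1] delimits request i's entries in kv_indices. A small, self-contained sketch of the same construction with toy data; the tensor contents are purely illustrative:

import torch

paged_kernel_lens = torch.tensor([3, 1, 2], dtype=torch.int32)   # cached KV length per request
batch_size = paged_kernel_lens.numel()

kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)           # -> [0, 3, 4, 6]

# Flattened token-pool slots, one slice per request (slot ids made up here).
per_request_slots = [torch.tensor([10, 11, 12]), torch.tensor([40]), torch.tensor([7, 8])]
kv_indices = torch.cat(per_request_slots, dim=0)                 # -> [10, 11, 12, 40, 7, 8]

# Request i's KV slots are kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
print(kv_indices[kv_indptr[1] : kv_indptr[2]])                   # tensor([40], dtype=torch.int64)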
@@ -769,18 +763,34 @@ class InputMetadata:
             ],
             dim=0,
         ).contiguous()
+        kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
 
-        if self.forward_mode == ForwardMode.EXTEND:
+        if self.forward_mode == ForwardMode.DECODE:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
+            )
+        else:
             # extend part
-            self.qo_indptr = torch.zeros(
+            qo_indptr = torch.zeros(
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
-            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
+            qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
 
             self.flashinfer_prefill_wrapper_ragged.end_forward()
             self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                self.qo_indptr,
-                self.qo_indptr.clone(),
+                qo_indptr,
+                qo_indptr,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
@@ -789,28 +799,15 @@ class InputMetadata:
 
             # cached part
             self.flashinfer_prefill_wrapper_paged.end_forward()
             self.flashinfer_prefill_wrapper_paged.begin_forward(
-                self.qo_indptr,
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
                 1,
             )
-        else:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )
 
     def init_extend_args(self):
         self.extend_seq_lens = self.seq_lens - self.prefix_lens
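The pattern is the same for all three wrappers: discard the plan built for the previous batch, then plan the next one from the freshly built index tensors. A hedged sketch of that per-batch lifecycle for the decode wrapper, using only the arguments visible in this diff; the helper name is invented for illustration:

def plan_decode_batch(wrapper, kv_indptr, kv_indices, kv_last_page_len,
                      num_qo_heads, num_kv_heads, head_dim, kv_dtype):
    # `wrapper` stands for a BatchDecodeWithPagedKVCacheWrapper such as
    # model_runner.flashinfer_decode_wrapper; the call mirrors the diff above.
    wrapper.end_forward()            # release the plan from the previous batch
    wrapper.begin_forward(
        kv_indptr,                   # CSR offsets into kv_indices, one slice per request
        kv_indices,                  # flattened KV slot ids for the whole batch
        kv_last_page_len,            # all ones here: with page size 1, the last page holds 1 token
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,                           # page size used by this KV pool
        pos_encoding_mode="NONE",
        data_type=kv_dtype,
    )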
@@ -822,7 +819,6 @@ class InputMetadata:
     def create(
         cls,
         model_runner,
-        tp_size,
         forward_mode,
         req_pool_indices,
         seq_lens,
@@ -833,9 +829,6 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
-        flashinfer_prefill_wrapper_ragged=None,
-        flashinfer_prefill_wrapper_paged=None,
-        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -845,9 +838,6 @@ class InputMetadata:
         if forward_mode == ForwardMode.DECODE:
             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices[0], seq_lens[0] - 1
-            ].item()
         else:
             seq_lens_cpu = seq_lens.cpu().numpy()
             prefix_lens_cpu = prefix_lens.cpu().numpy()
@@ -865,7 +855,6 @@ class InputMetadata:
             ),
             device="cuda",
         )
-            other_kv_index = None
 
         ret = cls(
             forward_mode=forward_mode,
@@ -882,12 +871,11 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
         )
 
         if forward_mode == ForwardMode.EXTEND:
@@ -895,8 +883,8 @@ class InputMetadata:
         if not global_server_args_dict.get("disable_flashinfer", False):
             ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // tp_size,
-                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.num_attention_heads // model_runner.tp_size,
+                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
                 model_runner.model_config.head_dim,
             )
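The two arguments are the per-tensor-parallel-rank head counts. As a worked example with hypothetical numbers (not from this commit): a GQA model with 32 attention heads, 8 KV heads, and head_dim 128 running at tp_size=4 gives each rank 8 query/output heads and 2 KV heads, assuming the KV heads divide evenly across ranks (get_num_kv_heads also handles the uneven case):

# Illustrative numbers only.
num_attention_heads, num_kv_heads, head_dim, tp_size = 32, 8, 128, 4

num_qo_heads_per_rank = num_attention_heads // tp_size   # 8 query/output heads per GPU
num_kv_heads_per_rank = num_kv_heads // tp_size           # 2 KV heads per GPU
print(num_qo_heads_per_rank, num_kv_heads_per_rank, head_dim)  # 8 2 128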
python/sglang/srt/managers/controller/model_runner.py  (view file @ af4e7910)
@@ -221,7 +221,6 @@ class ModelRunner:
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
@@ -229,9 +228,6 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -242,7 +238,6 @@ class ModelRunner:
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.DECODE,
-            tp_size=self.tp_size,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
@@ -252,9 +247,6 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
python/sglang/srt/server_args.py  (view file @ af4e7910)
@@ -53,6 +53,7 @@ class ServerArgs:
     disable_flashinfer: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
+    disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
     attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
@@ -294,6 +295,11 @@ class ServerArgs:
             action="store_true",
             help="Disable regex jump-forward",
         )
+        parser.add_argument(
+            "--disable-cuda-graph",
+            action="store_true",
+            help="Disable cuda graph.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
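With the flag registered as a store_true option, it can be passed at launch time like the other disable-* switches; a hypothetical invocation, where the model path is only an example:

python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --disable-cuda-graph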