sglang · Commits · 396a6924

Cleanup attention backend: flashinfer and triton (#611)

Unverified commit 396a6924, authored Jul 12, 2024 by Lianmin Zheng; committed by GitHub on Jul 12, 2024.
Parent: af4e7910

Showing 4 changed files with 178 additions and 159 deletions (+178, -159).
  python/sglang/srt/layers/radix_attention.py              +7   -8
  python/sglang/srt/layers/token_attention.py              +1   -2
  python/sglang/srt/managers/controller/infer_batch.py     +140 -119
  python/sglang/srt/managers/controller/model_runner.py    +30  -30
python/sglang/srt/layers/radix_attention.py  (+7, -8)

 """Radix attention."""

-import numpy as np
 import torch
 from flashinfer.cascade import merge_state
 from torch import nn

@@ -51,13 +50,13 @@ class RadixAttention(nn.Module):
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.prefix_lens,
+            input_metadata.triton_prefix_lens,
             input_metadata.extend_start_loc,
             input_metadata.extend_seq_lens,
-            input_metadata.max_seq_len,
-            input_metadata.max_extend_len,
+            input_metadata.triton_max_seq_len,
+            input_metadata.triton_max_extend_len,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
         )

@@ -75,9 +74,9 @@ class RadixAttention(nn.Module):
             o.view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.max_seq_len,
+            input_metadata.triton_max_seq_len,
             input_metadata.total_num_tokens,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,

@@ -95,7 +94,7 @@ class RadixAttention(nn.Module):
             logits_soft_cap=self.logit_cap,
         )

-        if input_metadata.no_prefix:
+        if input_metadata.extend_no_prefix:
             o = o1
         else:
             o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
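The extend path above computes attention in two pieces via forward_return_lse (new tokens against themselves with the ragged wrapper, and against the cached prefix with the paged wrapper) and fuses them with flashinfer.cascade.merge_state. As a reference for what that merge computes, here is a minimal PyTorch sketch of the log-sum-exp combination, assuming natural-log LSE values; it is illustrative math only, not the library's kernel:

import torch

def merge_attention_states(o1, s1, o2, s2):
    """Combine partial attention outputs computed over disjoint key segments.

    o1, o2: [num_tokens, num_heads, head_dim] partial outputs.
    s1, s2: [num_tokens, num_heads] log-sum-exp of the attention logits
            for each segment (natural log assumed in this sketch).
    """
    # Softmax over the two LSE values gives each segment's weight.
    w = torch.softmax(torch.stack([s1, s2], dim=-1), dim=-1)  # [tokens, heads, 2]
    o = w[..., 0:1] * o1 + w[..., 1:2] * o2
    s = torch.logaddexp(s1, s2)  # LSE over the union of both key segments
    return o, s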
python/sglang/srt/layers/token_attention.py  (+1, -2)

@@ -312,7 +312,7 @@ def token_attention_fwd(
     b_seq_len,
     max_len_in_batch,
     total_num_tokens,
-    sm_scale=None,
+    sm_scale,
     logit_cap=-1,
     att_m=None,
 ):

@@ -320,7 +320,6 @@ def token_attention_fwd(
         att_m = torch.empty(
             (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
         )

-    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
     _token_att_m_fwd(
         q,
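With the sm_scale=None default and its in-function fallback removed, callers of token_attention_fwd now supply the softmax scale themselves (RadixAttention passes its own scaling attribute). For reference, the conventional value the removed fallback computed, with head_dim as an illustrative stand-in for Lq:

import math

head_dim = 128                        # illustrative value, not taken from the diff
sm_scale = 1.0 / math.sqrt(head_dim)  # the usual 1/sqrt(d) softmax scaling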
python/sglang/srt/managers/controller/infer_batch.py  (+140, -119)

@@ -75,6 +75,7 @@ class Req:
     """Store all inforamtion of a request."""

     def __init__(self, rid, origin_input_text, origin_input_ids):
+        # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding

@@ -97,6 +98,11 @@ class Req:
         self.image_offset = 0
         self.pad_value = None

+        # Prefix info
+        self.extend_input_len = 0
+        self.prefix_indices = []
+        self.last_node = None
+
         # Sampling parameters
         self.sampling_params = None
         self.stream = False

@@ -105,11 +111,6 @@ class Req:
         self.tokenizer = None
         self.finished_reason = None

-        # Prefix info
-        self.extend_input_len = 0
-        self.prefix_indices = []
-        self.last_node = None
-
         # Logprobs
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -261,35 +262,36 @@ class Req:
 class Batch:
     """Store all inforamtion of a batch."""

+    # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
     tree_cache: RadixCache

-    # batched arguments to model runner
+    # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
+    out_cache_cont_start: int = None
+    out_cache_cont_end: int = None

-    # for processing logprobs
+    # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

-    # for multimodal
+    # For multimodal
     pixel_values: List[torch.Tensor] = None
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None

-    # other arguments for control
+    # Other arguments for control
     output_ids: torch.Tensor = None
     extend_num_tokens: int = None

-    # batched sampling params
+    # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None

@@ -312,8 +314,8 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0

-    # whether batch has at least 1 streaming request
     def has_stream(self) -> bool:
+        # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)

     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):

@@ -347,7 +349,7 @@ class Batch:
         position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)

-        # Alloc mem
+        # Allocate memory
         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

@@ -703,7 +705,6 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
     return probs_sort, probs_idx


-
 @dataclass
 class InputMetadata:
     """Store all inforamtion of a forward pass."""
@@ -711,110 +712,37 @@ class InputMetadata:
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
-    max_seq_len: int
     req_pool_indices: torch.Tensor
-    start_loc: torch.Tensor
     seq_lens: torch.Tensor
-    prefix_lens: torch.Tensor
     positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool

-    # for extend
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    max_extend_len: int = 0
+    # For extend
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+    extend_no_prefix: bool

+    # Output location of the KV cache
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
+    out_cache_cont_start: int = None
+    out_cache_cont_end: int = None

+    # Output options
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

-    # for flashinfer
-    qo_indptr: torch.Tensor = None
-    kv_indptr: torch.Tensor = None
-    kv_indices: torch.Tensor = None
-    kv_last_page_len: torch.Tensor = None
+    # Trition attention backend
+    triton_max_seq_len: int = 0
+    triton_max_extend_len: int = 0
+    triton_start_loc: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None

+    # FlashInfer attention backend
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

-    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if self.forward_mode == ForwardMode.DECODE:
-            paged_kernel_lens = self.seq_lens
-        else:
-            paged_kernel_lens = self.prefix_lens
-            self.no_prefix = torch.all(self.prefix_lens == 0)
-
-        kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(self.batch_size)
-            ],
-            dim=0,
-        ).contiguous()
-        kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-
-        if self.forward_mode == ForwardMode.DECODE:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )
-        else:
-            # extend part
-            qo_indptr = torch.zeros(
-                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-            )
-            qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-            self.flashinfer_prefill_wrapper_ragged.end_forward()
-            self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-            # cached part
-            self.flashinfer_prefill_wrapper_paged.end_forward()
-            self.flashinfer_prefill_wrapper_paged.begin_forward(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-            )
-
-    def init_extend_args(self):
-        self.extend_seq_lens = self.seq_lens - self.prefix_lens
-        self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
     @classmethod
     def create(
         cls,

@@ -830,14 +758,20 @@ class InputMetadata:
         top_logprobs_nums=None,
         return_logprob=False,
     ):
+        if not model_runner.server_args.disable_flashinfer:
+            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens)
+
         batch_size = len(req_pool_indices)
-        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-        total_num_tokens = int(torch.sum(seq_lens))
-        max_seq_len = int(torch.max(seq_lens))

         if forward_mode == ForwardMode.DECODE:
             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            extend_seq_lens = extend_start_loc = extend_no_prefix = None
+            if not model_runner.server_args.disable_flashinfer:
+                # This variable is not needed in this case,
+                # we do not compute it to make it compatbile with cuda graph.
+                total_num_tokens = None
+            else:
+                total_num_tokens = int(torch.sum(seq_lens))
         else:
             seq_lens_cpu = seq_lens.cpu().numpy()
             prefix_lens_cpu = prefix_lens.cpu().numpy()

@@ -855,22 +789,27 @@ class InputMetadata:
             ),
             device="cuda",
         )

+        extend_seq_lens = seq_lens - prefix_lens
+        extend_start_loc = torch.zeros_like(seq_lens)
+        extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+        extend_no_prefix = torch.all(prefix_lens == 0)
+        total_num_tokens = int(torch.sum(seq_lens))
+
         ret = cls(
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
-            max_seq_len=max_seq_len,
             req_pool_indices=req_pool_indices,
-            start_loc=start_loc,
             seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
             positions=positions,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
+            extend_seq_lens=extend_seq_lens,
+            extend_start_loc=extend_start_loc,
+            extend_no_prefix=extend_no_prefix,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,

@@ -878,14 +817,96 @@ class InputMetadata:
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
         )

-        if forward_mode == ForwardMode.EXTEND:
-            ret.init_extend_args()
-
-        if not global_server_args_dict.get("disable_flashinfer", False):
-            ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // model_runner.tp_size,
-                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-                model_runner.model_config.head_dim,
-            )
+        if model_runner.server_args.disable_flashinfer:
+            (
+                ret.triton_max_seq_len,
+                ret.triton_max_extend_len,
+                ret.triton_start_loc,
+                ret.triton_prefix_lens,
+            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)

         return ret
+
+
+def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens):
+    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    head_dim = model_runner.model_config.head_dim
+    batch_size = len(req_pool_indices)
+
+    if forward_mode == ForwardMode.DECODE:
+        paged_kernel_lens = seq_lens
+    else:
+        paged_kernel_lens = prefix_lens
+
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+    kv_indices = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+            ]
+            for i in range(batch_size)
+        ],
+        dim=0,
+    ).contiguous()
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+    if forward_mode == ForwardMode.DECODE:
+        model_runner.flashinfer_decode_wrapper.end_forward()
+        model_runner.flashinfer_decode_wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+    else:
+        # extend part
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+            qo_indptr,
+            qo_indptr,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+        )
+
+        # cached part
+        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+
+
+def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    batch_size = len(seq_lens)
+    max_seq_len = int(torch.max(seq_lens))
+    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    if forward_mode == ForwardMode.DECODE:
+        max_extend_len = None
+    else:
+        extend_seq_lens = seq_lens - prefix_lens
+        max_extend_len = int(torch.max(extend_seq_lens))
+
+    return max_seq_len, max_extend_len, start_loc, prefix_lens
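Both backends index a flattened token dimension with prefix sums of the per-request lengths, as built above. A small self-contained illustration with made-up lengths (CPU tensors so it runs anywhere; only the seq_lens values are hypothetical):

import torch

# Toy per-request sequence lengths for a batch of 3 requests.
seq_lens = torch.tensor([3, 5, 2])

# Triton backend: start_loc is an exclusive prefix sum, giving each request's
# offset into the flattened token dimension.
start_loc = torch.zeros_like(seq_lens)
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
print(start_loc)  # tensor([0, 3, 8])

# FlashInfer backend: kv_indptr is the CSR-style pointer array with a leading 0,
# so request i owns kv_indices[kv_indptr[i]:kv_indptr[i + 1]].
kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int64)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
print(kv_indptr)  # tensor([ 0,  3,  8, 10])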
python/sglang/srt/managers/controller/model_runner.py  (+30, -30)

@@ -182,39 +182,39 @@ class ModelRunner:
         return c

     def init_flash_infer(self):
-        if not global_server_args_dict.get("disable_flashinfer", False):
-            from flashinfer import (
-                BatchDecodeWithPagedKVCacheWrapper,
-                BatchPrefillWithPagedKVCacheWrapper,
-                BatchPrefillWithRaggedKVCacheWrapper,
-            )
-            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-            if not _grouped_size_compiled_for_decode_kernels(
-                self.model_config.num_attention_heads // self.tp_size,
-                self.model_config.get_num_kv_heads(self.tp_size),
-            ):
-                use_tensor_cores = True
-            else:
-                use_tensor_cores = False
-
-            workspace_buffers = torch.empty(
-                2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
-            )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffers[1], "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
-            )
-        else:
-            self.flashinfer_prefill_wrapper_ragged = (
-                self.flashinfer_prefill_wrapper_paged
-            ) = None
-            self.flashinfer_decode_wrapper = None
+        if self.server_args.disable_flashinfer:
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None
+            return
+
+        from flashinfer import (
+            BatchDecodeWithPagedKVCacheWrapper,
+            BatchPrefillWithPagedKVCacheWrapper,
+            BatchPrefillWithRaggedKVCacheWrapper,
+        )
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_config.num_attention_heads // self.tp_size,
+            self.model_config.get_num_kv_heads(self.tp_size),
+        ):
+            use_tensor_cores = True
+        else:
+            use_tensor_cores = False
+
+        workspace_buffers = torch.empty(
+            3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+        )
+        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            workspace_buffers[0], "NHD"
+        )
+        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+            workspace_buffers[1], "NHD"
+        )
+        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+        )

     @torch.inference_mode()
     def forward_extend(self, batch: Batch):
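The rewritten init_flash_infer turns the disable_flashinfer branch into an early return and allocates three 96 MB workspace buffers instead of two, so the decode wrapper no longer shares buffer 0 with the ragged prefill wrapper. A rough sketch of the resulting layout (CPU tensor here only so the snippet runs without a GPU; sizes copied from the diff):

import torch

# One 96 MB workspace per FlashInfer wrapper after this change:
#   [0] -> BatchPrefillWithRaggedKVCacheWrapper
#   [1] -> BatchPrefillWithPagedKVCacheWrapper
#   [2] -> BatchDecodeWithPagedKVCacheWrapper
workspace_buffers = torch.empty(3, 96 * 1024 * 1024, dtype=torch.uint8, device="cpu")
print(workspace_buffers[2].numel() / 2**20)  # 96.0 (MiB per wrapper)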