sglang · Commit 396a6924
"tensoradapter/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "9a00cf194fcf994b2527cd927d691144f5e9c47b"
Cleanup attention backend: flashinfer and triton (#611)

Unverified commit authored on Jul 12, 2024 by Lianmin Zheng and committed by GitHub on Jul 12, 2024.
Parent: af4e7910
Showing 4 changed files with 178 additions and 159 deletions (+178 -159).
python/sglang/srt/layers/radix_attention.py              +7   -8
python/sglang/srt/layers/token_attention.py              +1   -2
python/sglang/srt/managers/controller/infer_batch.py    +140 -119
python/sglang/srt/managers/controller/model_runner.py    +30  -30

python/sglang/srt/layers/radix_attention.py
"""Radix attention."""
"""Radix attention."""
import
numpy
as
np
import
torch
import
torch
from
flashinfer.cascade
import
merge_state
from
flashinfer.cascade
import
merge_state
from
torch
import
nn
from
torch
import
nn
@@ -51,13 +50,13 @@ class RadixAttention(nn.Module):
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.prefix_lens,
+            input_metadata.triton_prefix_lens,
             input_metadata.extend_start_loc,
             input_metadata.extend_seq_lens,
-            input_metadata.max_seq_len,
+            input_metadata.triton_max_seq_len,
-            input_metadata.max_extend_len,
+            input_metadata.triton_max_extend_len,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
         )
@@ -75,9 +74,9 @@ class RadixAttention(nn.Module):
             o.view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.start_loc,
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.max_seq_len,
+            input_metadata.triton_max_seq_len,
             input_metadata.total_num_tokens,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,
@@ -95,7 +94,7 @@ class RadixAttention(nn.Module):
             logits_soft_cap=self.logit_cap,
         )
 
-        if input_metadata.no_prefix:
+        if input_metadata.extend_no_prefix:
             o = o1
         else:
             o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
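
The last hunk above is the FlashInfer extend path: attention over the newly extended tokens (the ragged wrapper, giving o1/s1) is combined with attention over the cached prefix (the paged wrapper, giving o2/s2), unless extend_no_prefix says there is no cached prefix at all. As a rough illustration of what merge_state does with those two partial results, here is a pure-torch reference; it assumes natural-log log-sum-exp values and illustrative shapes, so treat it as a sketch rather than FlashInfer's implementation.

# Illustrative sketch only: merge two partial attention results (o1, s1) and
# (o2, s2), where o* are partial outputs and s* their log-sum-exp values.
# This mirrors the role of flashinfer.cascade.merge_state in the hunk above,
# assuming natural-log LSE; it is not the library's code.
import torch

def merge_partial_attention(o1, s1, o2, s2):
    # o*: [tokens, heads, head_dim], s*: [tokens, heads]
    m = torch.maximum(s1, s2)                  # subtract the max for stability
    w1 = torch.exp(s1 - m).unsqueeze(-1)
    w2 = torch.exp(s2 - m).unsqueeze(-1)
    return (o1 * w1 + o2 * w2) / (w1 + w2)     # softmax-consistent combination

o1, s1 = torch.randn(4, 2, 8), torch.randn(4, 2)
o2, s2 = torch.randn(4, 2, 8), torch.randn(4, 2)
print(merge_partial_attention(o1, s1, o2, s2).shape)  # torch.Size([4, 2, 8])

When extend_no_prefix is true there is nothing to merge, which is why the diff keeps the o = o1 shortcut.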

python/sglang/srt/layers/token_attention.py
@@ -312,7 +312,7 @@ def token_attention_fwd(
     b_seq_len,
     max_len_in_batch,
     total_num_tokens,
-    sm_scale=None,
+    sm_scale,
     logit_cap=-1,
     att_m=None,
 ):
@@ -320,7 +320,6 @@ def token_attention_fwd(
         att_m = torch.empty(
             (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
         )
 
-    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
     _token_att_m_fwd(
         q,
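
With the sm_scale=None default gone, token_attention_fwd no longer derives the softmax scale itself; the caller supplies it explicitly (RadixAttention passes sm_scale=self.scaling, as the first file shows). A minimal sketch of the value the removed fallback used to compute, with an assumed head dimension:

# Sketch: the scale the deleted in-kernel fallback would have computed.
# head_dim stands in for Lq (the query head dimension); 128 is an assumption.
import math

head_dim = 128
sm_scale = 1.0 / math.sqrt(head_dim)   # same as the removed 1.0 / (Lq**0.5)
print(sm_scale)                        # ~0.0884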

python/sglang/srt/managers/controller/infer_batch.py
@@ -75,6 +75,7 @@ class Req:
     """Store all inforamtion of a request."""
 
     def __init__(self, rid, origin_input_text, origin_input_ids):
+        # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
         self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
@@ -97,6 +98,11 @@ class Req:
         self.image_offset = 0
         self.pad_value = None
 
+        # Prefix info
+        self.extend_input_len = 0
+        self.prefix_indices = []
+        self.last_node = None
+
         # Sampling parameters
         self.sampling_params = None
         self.stream = False
@@ -105,11 +111,6 @@ class Req:
         self.tokenizer = None
         self.finished_reason = None
 
-        # Prefix info
-        self.extend_input_len = 0
-        self.prefix_indices = []
-        self.last_node = None
-
         # Logprobs
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -261,35 +262,36 @@ class Req:
 class Batch:
     """Store all inforamtion of a batch."""
 
+    # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
     tree_cache: RadixCache
 
-    # batched arguments to model runner
+    # Batched arguments to model runner
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
+    out_cache_cont_start: int = None
-    out_cache_cont_end: torch.Tensor = None
+    out_cache_cont_end: int = None
 
-    # for processing logprobs
+    # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # for multimodal
+    # For multimodal
     pixel_values: List[torch.Tensor] = None
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None
 
-    # other arguments for control
+    # Other arguments for control
     output_ids: torch.Tensor = None
     extend_num_tokens: int = None
 
-    # batched sampling params
+    # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
     top_ks: torch.Tensor = None
@@ -312,8 +314,8 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0
 
-    # whether batch has at least 1 streaming request
     def has_stream(self) -> bool:
+        # Return whether batch has at least 1 streaming request
         return any(r.stream for r in self.reqs)
 
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
@@ -347,7 +349,7 @@ class Batch:
         position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
 
-        # Alloc mem
+        # Allocate memory
         seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
@@ -703,7 +705,6 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
     return probs_sort, probs_idx
 
 
 @dataclass
 class InputMetadata:
     """Store all inforamtion of a forward pass."""
@@ -711,110 +712,37 @@ class InputMetadata:
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
-    max_seq_len: int
     req_pool_indices: torch.Tensor
-    start_loc: torch.Tensor
     seq_lens: torch.Tensor
-    prefix_lens: torch.Tensor
     positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: TokenToKVPool
 
-    # for extend
+    # For extend
-    extend_seq_lens: torch.Tensor = None
+    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor = None
+    extend_start_loc: torch.Tensor
-    max_extend_len: int = 0
+    extend_no_prefix: bool
 
+    # Output location of the KV cache
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
+    out_cache_cont_start: int = None
-    out_cache_cont_end: torch.Tensor = None
+    out_cache_cont_end: int = None
 
+    # Output options
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # for flashinfer
+    # Trition attention backend
-    qo_indptr: torch.Tensor = None
+    triton_max_seq_len: int = 0
-    kv_indptr: torch.Tensor = None
+    triton_max_extend_len: int = 0
-    kv_indices: torch.Tensor = None
+    triton_start_loc: torch.Tensor = None
-    kv_last_page_len: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None
 
+    # FlashInfer attention backend
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
-    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if self.forward_mode == ForwardMode.DECODE:
-            paged_kernel_lens = self.seq_lens
-        else:
-            paged_kernel_lens = self.prefix_lens
-            self.no_prefix = torch.all(self.prefix_lens == 0)
-
-        kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(self.batch_size)
-            ],
-            dim=0,
-        ).contiguous()
-        kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-
-        if self.forward_mode == ForwardMode.DECODE:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )
-        else:
-            # extend part
-            qo_indptr = torch.zeros(
-                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-            )
-            qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-            self.flashinfer_prefill_wrapper_ragged.end_forward()
-            self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-            # cached part
-            self.flashinfer_prefill_wrapper_paged.end_forward()
-            self.flashinfer_prefill_wrapper_paged.begin_forward(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-            )
-
-    def init_extend_args(self):
-        self.extend_seq_lens = self.seq_lens - self.prefix_lens
-        self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
     @classmethod
     def create(
         cls,
@@ -830,14 +758,20 @@ class InputMetadata:
         top_logprobs_nums=None,
         return_logprob=False,
     ):
+        if not model_runner.server_args.disable_flashinfer:
+            init_flashinfer_args(
+                forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens
+            )
+
         batch_size = len(req_pool_indices)
-        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-        total_num_tokens = int(torch.sum(seq_lens))
-        max_seq_len = int(torch.max(seq_lens))
 
         if forward_mode == ForwardMode.DECODE:
             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            extend_seq_lens = extend_start_loc = extend_no_prefix = None
+            if not model_runner.server_args.disable_flashinfer:
+                # This variable is not needed in this case,
+                # we do not compute it to make it compatbile with cuda graph.
+                total_num_tokens = None
+            else:
+                total_num_tokens = int(torch.sum(seq_lens))
         else:
             seq_lens_cpu = seq_lens.cpu().numpy()
             prefix_lens_cpu = prefix_lens.cpu().numpy()
@@ -855,22 +789,27 @@ class InputMetadata:
                 ),
                 device="cuda",
             )
+            extend_seq_lens = seq_lens - prefix_lens
+            extend_start_loc = torch.zeros_like(seq_lens)
+            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+            extend_no_prefix = torch.all(prefix_lens == 0)
+            total_num_tokens = int(torch.sum(seq_lens))
 
         ret = cls(
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
-            max_seq_len=max_seq_len,
             req_pool_indices=req_pool_indices,
-            start_loc=start_loc,
             seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
             positions=positions,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
+            extend_seq_lens=extend_seq_lens,
+            extend_start_loc=extend_start_loc,
+            extend_no_prefix=extend_no_prefix,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
@@ -878,14 +817,96 @@ class InputMetadata:
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
         )
 
-        if forward_mode == ForwardMode.EXTEND:
-            ret.init_extend_args()
-
-        if not global_server_args_dict.get("disable_flashinfer", False):
-            ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // model_runner.tp_size,
-                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-                model_runner.model_config.head_dim,
-            )
+        if model_runner.server_args.disable_flashinfer:
+            (
+                ret.triton_max_seq_len,
+                ret.triton_max_extend_len,
+                ret.triton_start_loc,
+                ret.triton_prefix_lens,
+            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
 
         return ret
+
+
+def init_flashinfer_args(
+    forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens
+):
+    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    head_dim = model_runner.model_config.head_dim
+    batch_size = len(req_pool_indices)
+
+    if forward_mode == ForwardMode.DECODE:
+        paged_kernel_lens = seq_lens
+    else:
+        paged_kernel_lens = prefix_lens
+
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+    kv_indices = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+            ]
+            for i in range(batch_size)
+        ],
+        dim=0,
+    ).contiguous()
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+    if forward_mode == ForwardMode.DECODE:
+        model_runner.flashinfer_decode_wrapper.end_forward()
+        model_runner.flashinfer_decode_wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+    else:
+        # extend part
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+            qo_indptr,
+            qo_indptr,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+        )
+
+        # cached part
+        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+
+
+def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    batch_size = len(seq_lens)
+    max_seq_len = int(torch.max(seq_lens))
+    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    if forward_mode == ForwardMode.DECODE:
+        max_extend_len = None
+    else:
+        extend_seq_lens = seq_lens - prefix_lens
+        max_extend_len = int(torch.max(extend_seq_lens))
+
+    return max_seq_len, max_extend_len, start_loc, prefix_lens
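
The net effect in infer_batch.py is that InputMetadata no longer carries backend-specific methods: create() calls the module-level init_flashinfer_args() when FlashInfer is enabled, and otherwise fills the triton_* fields from init_triton_args(). The sketch below reruns the arithmetic of init_triton_args on CPU tensors (the real function allocates on "cuda") to show what those fields hold for a small assumed batch:

# Sketch of the quantities init_triton_args computes, on CPU for portability.
# The batch below (three requests with cached prefixes) is an assumed example.
import torch

seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)      # total tokens per request
prefix_lens = torch.tensor([2, 0, 4], dtype=torch.int32)   # tokens already cached

max_seq_len = int(torch.max(seq_lens))                      # -> triton_max_seq_len (7)
start_loc = torch.zeros((len(seq_lens),), dtype=torch.int32)
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)          # -> triton_start_loc ([0, 5, 8])
max_extend_len = int(torch.max(seq_lens - prefix_lens))     # -> triton_max_extend_len (3)

print(max_seq_len, start_loc.tolist(), max_extend_len)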

python/sglang/srt/managers/controller/model_runner.py
@@ -182,39 +182,39 @@ class ModelRunner:
         return c
 
     def init_flash_infer(self):
-        if not global_server_args_dict.get("disable_flashinfer", False):
-            from flashinfer import (
-                BatchDecodeWithPagedKVCacheWrapper,
-                BatchPrefillWithPagedKVCacheWrapper,
-                BatchPrefillWithRaggedKVCacheWrapper,
-            )
-            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-            if not _grouped_size_compiled_for_decode_kernels(
-                self.model_config.num_attention_heads // self.tp_size,
-                self.model_config.get_num_kv_heads(self.tp_size),
-            ):
-                use_tensor_cores = True
-            else:
-                use_tensor_cores = False
-
-            workspace_buffers = torch.empty(
-                2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
-            )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffers[1], "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
-            )
-        else:
-            self.flashinfer_prefill_wrapper_ragged = (
-                self.flashinfer_prefill_wrapper_paged
-            ) = None
-            self.flashinfer_decode_wrapper = None
+        if self.server_args.disable_flashinfer:
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None
+            return
+
+        from flashinfer import (
+            BatchDecodeWithPagedKVCacheWrapper,
+            BatchPrefillWithPagedKVCacheWrapper,
+            BatchPrefillWithRaggedKVCacheWrapper,
+        )
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_config.num_attention_heads // self.tp_size,
+            self.model_config.get_num_kv_heads(self.tp_size),
+        ):
+            use_tensor_cores = True
+        else:
+            use_tensor_cores = False
+
+        workspace_buffers = torch.empty(
+            3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+        )
+        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            workspace_buffers[0], "NHD"
+        )
+        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+            workspace_buffers[1], "NHD"
+        )
+        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+        )
 
     @torch.inference_mode()
     def forward_extend(self, batch: Batch):
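
The model_runner.py hunk restructures init_flash_infer() around an early return: when FlashInfer is disabled the three wrapper attributes stay None, otherwise three separate 96 MB workspace buffers are allocated (the old code allocated two and reused buffer 0 for both the ragged prefill and the decode wrapper). A small stand-in sketch of that control flow; the function and dict below are illustrative, not the sglang API:

# Sketch of the reorganized control flow; names here are stand-ins.
def plan_flashinfer_init(disable_flashinfer: bool, grouped_kernels_compiled: bool):
    if disable_flashinfer:
        return None                                        # wrappers remain None, early return
    return {
        "use_tensor_cores": not grouped_kernels_compiled,  # same heuristic as the diff
        "workspace_buffers": 3,                            # ragged=0, paged=1, decode=2
    }

print(plan_flashinfer_init(True, True))    # None
print(plan_flashinfer_init(False, False))  # {'use_tensor_cores': True, 'workspace_buffers': 3}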