sglang / Commits / 100f5b8b

Simplify flashinfer dispatch (#1552)

Unverified commit 100f5b8b, authored Oct 01, 2024 by Liangsheng Yin, committed by GitHub on Oct 01, 2024.
parent 619bb6dd
Showing 5 changed files with 97 additions and 76 deletions:

  python/sglang/srt/layers/attention/flashinfer_backend.py  (+21 -5)
  python/sglang/srt/layers/attention/flashinfer_utils.py    (+67 -69)
  python/sglang/srt/layers/radix_attention.py               (+3 -1)
  python/sglang/srt/model_executor/model_runner.py          (+5 -0)
  python/sglang/srt/models/gemma2.py                         (+1 -1)

python/sglang/srt/layers/attention/flashinfer_backend.py (view file @ 100f5b8b)

@@ -14,7 +14,10 @@ import torch.nn as nn
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_utils import update_flashinfer_indices
+from sglang.srt.layers.attention.flashinfer_utils import (
+    WrapperDispatch,
+    update_flashinfer_indices,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_hip

@@ -53,10 +56,19 @@ class FlashInferAttnBackend(AttentionBackend):
             device="cuda",
         )

+        assert not (
+            model_runner.sliding_window_size is not None
+            and model_runner.has_cross_attention
+        ), "Sliding window and cross attention are not supported together"
+
+        self.num_wrappers = 1
+        self.dispatch_reason = None
         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
-        else:
-            self.num_wrappers = 1
+            self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
+        elif model_runner.has_cross_attention:
+            self.num_wrappers = 2
+            self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION

         # NOTE: we do not use ragged attention when there are multiple wrappers
         self.prefill_wrapper_ragged = (
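
The constructor now records why two wrappers are needed instead of only how many. As a quick illustration, here is a minimal, self-contained sketch of that decision; the choose_wrappers helper is hypothetical and not part of sglang, it only mirrors the branch structure shown above:

from enum import Enum, auto
from typing import Optional, Tuple


class WrapperDispatch(Enum):
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


def choose_wrappers(
    sliding_window_size: Optional[int], has_cross_attention: bool
) -> Tuple[int, Optional[WrapperDispatch]]:
    # Two wrappers are needed whenever layers disagree on their KV view;
    # the two reasons are mutually exclusive, as the assert above enforces.
    assert not (sliding_window_size is not None and has_cross_attention)
    if sliding_window_size is not None:
        return 2, WrapperDispatch.SLIDING_WINDOW
    if has_cross_attention:
        return 2, WrapperDispatch.CROSS_ATTENTION
    return 1, None


print(choose_wrappers(4096, False))  # two wrappers, SLIDING_WINDOW dispatch
print(choose_wrappers(None, True))   # two wrappers, CROSS_ATTENTION dispatch
print(choose_wrappers(None, False))  # single wrapper, no dispatch reason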

@@ -88,8 +100,12 @@ class FlashInferAttnBackend(AttentionBackend):
         if self.num_wrappers == 1:
             return 0

-        # TODO: make sure the idx is related to sliding window size
-        return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            return layer.is_cross_attention
+
+        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
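
At runtime each attention layer is routed to one of the wrappers. Below is a hedged, standalone sketch of the mapping implemented by _get_wrapper_idx above; the Layer dataclass is a stand-in for RadixAttention, not real sglang code. Under sliding-window dispatch, windowed layers go to wrapper 0 and global-attention layers (sentinel -1) to wrapper 1; under cross-attention dispatch, cross-attention layers go to wrapper 1.

from dataclasses import dataclass


@dataclass
class Layer:
    # Stand-in for RadixAttention: only the two fields the dispatch reads.
    sliding_window_size: int = -1
    is_cross_attention: bool = False


def wrapper_idx_sliding_window(layer: Layer) -> int:
    # Windowed layers -> wrapper 0; global-attention layers (-1) -> wrapper 1.
    return int(layer.sliding_window_size == -1)


def wrapper_idx_cross_attention(layer: Layer) -> int:
    # Self-attention layers -> wrapper 0; cross-attention layers -> wrapper 1.
    return int(layer.is_cross_attention)


print(wrapper_idx_sliding_window(Layer(sliding_window_size=4096)))  # 0
print(wrapper_idx_sliding_window(Layer(sliding_window_size=-1)))    # 1
print(wrapper_idx_cross_attention(Layer(is_cross_attention=True)))  # 1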

python/sglang/srt/layers/attention/flashinfer_utils.py (view file @ 100f5b8b)

+from enum import Enum, auto
+
 import torch
 import triton
 import triton.language as tl


+class WrapperDispatch(Enum):
+    SLIDING_WINDOW = auto()
+    CROSS_ATTENTION = auto()
+
+
 @triton.jit
 def create_flashinfer_kv_indices_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]

@@ -80,67 +87,6 @@ class FlashinferUpdater:
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )

-    def _init_indices_no_sliding_window(self):
-        if self.use_ragged:
-            paged_kernel_lens = self.prefix_lens
-        else:
-            paged_kernel_lens = self.seq_lens
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_indices = torch.empty(
-            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-        )
-
-        create_flashinfer_kv_indices_triton[(self.batch_size,)](
-            self.model_runner.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            paged_kernel_lens,
-            self.kv_indptr,
-            None,
-            self.kv_indices,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
-        )
-
-    def _init_indices_sliding_window(self, wrapper_id):
-        if wrapper_id == 0:
-            # window attention use paged only
-            if self.forward_mode.is_decode():
-                paged_kernel_lens = torch.minimum(
-                    self.seq_lens,
-                    torch.tensor(self.model_runner.sliding_window_size + 1),
-                )
-            else:
-                paged_kernel_lens = torch.minimum(
-                    self.seq_lens,
-                    torch.tensor(self.model_runner.sliding_window_size)
-                    + self.seq_lens
-                    - self.prefix_lens,
-                )
-        else:
-            # full attention
-            paged_kernel_lens = self.seq_lens
-
-        kv_start_idx = self.seq_lens - paged_kernel_lens
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_indices = torch.empty(
-            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-        )
-
-        create_flashinfer_kv_indices_triton[(self.batch_size,)](
-            self.model_runner.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            paged_kernel_lens,
-            self.kv_indptr,
-            kv_start_idx,
-            self.kv_indices,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
-        )
-
     def _update_decode_indices(self, decode_wrapper):
         assert not isinstance(decode_wrapper, list)
         decode_wrapper.end_forward()

@@ -189,8 +135,53 @@ class FlashinferUpdater:
             1,
         )

-    def update_indices_no_sliding_window(self):
-        self._init_indices_no_sliding_window()
+    def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
+        if dispatch_reason is None:
+            if self.use_ragged:
+                paged_kernel_lens = self.prefix_lens
+            else:
+                paged_kernel_lens = self.seq_lens
+            self.kv_start_idx = None
+        elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            if wrapper_id == 0:
+                # window attention use paged only
+                if self.forward_mode.is_decode():
+                    paged_kernel_lens = torch.minimum(
+                        self.seq_lens,
+                        torch.tensor(self.model_runner.sliding_window_size + 1),
+                    )
+                else:
+                    paged_kernel_lens = torch.minimum(
+                        self.seq_lens,
+                        torch.tensor(self.model_runner.sliding_window_size)
+                        + self.seq_lens
+                        - self.prefix_lens,
+                    )
+            else:
+                # full attention
+                paged_kernel_lens = self.seq_lens
+            self.kv_start_idx = self.seq_lens - paged_kernel_lens
+
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_indices = torch.empty(
+            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+        )
+
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.model_runner.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            paged_kernel_lens,
+            self.kv_indptr,
+            self.kv_start_idx,
+            self.kv_indices,
+            self.model_runner.req_to_token_pool.req_to_token.size(1),
+        )
+
+    def _update_indicess_single_wrapper(self):
+        self._get_indices()

         if self.forward_mode.is_decode():
             self._update_decode_indices(self.decode_wrappers[0])
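
_get_indices produces a CSR-style layout: kv_indptr is the prefix sum of per-request KV lengths, and kv_indices is the flattened list of KV-cache slots gathered from req_to_token. The Triton kernel performs this gather on the GPU; the snippet below is a CPU-only torch sketch of the same layout on a toy batch (table contents and lengths are made up, and the sliding-window kv_start_idx offset is omitted):

import torch

# Toy req_to_token table: one row per request pool slot, values are KV-cache
# slot ids (made-up numbers for illustration only).
req_to_token = torch.tensor(
    [
        [10, 11, 12, 13, 0],   # pool slot 0
        [20, 21, 22, 0, 0],    # pool slot 1 (not in this batch)
        [30, 31, 32, 33, 34],  # pool slot 2
    ],
    dtype=torch.int32,
)
req_pool_indices = torch.tensor([0, 2])  # this batch uses pool slots 0 and 2
paged_kernel_lens = torch.tensor([4, 5], dtype=torch.int32)  # KV length per request

batch_size = req_pool_indices.numel()
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# CPU analogue of create_flashinfer_kv_indices_triton: gather each request's
# KV slots into one flat array.
kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
for i in range(batch_size):
    row = req_to_token[req_pool_indices[i]]
    kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = row[: paged_kernel_lens[i]]

print(kv_indptr.tolist())   # [0, 4, 9]
print(kv_indices.tolist())  # [10, 11, 12, 13, 30, 31, 32, 33, 34]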

@@ -200,11 +191,13 @@ class FlashinferUpdater:
             self.prefill_wrappers_paged[0],
         )

-    def update_indices_sliding_window(self):
-        assert self.use_ragged is False
+    def _update_indices_cross_attention(self):
+        pass
+
+    def _update_indices_sliding_window(self):
+        assert self.use_ragged is False
         for wrapper_id in range(2):
-            self._init_indices_sliding_window(wrapper_id)
+            self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
             if self.forward_mode.is_decode():
                 self._update_decode_indices(self.decode_wrappers[wrapper_id])
             else:
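
For wrapper 0 under sliding-window dispatch, paged_kernel_lens is clamped so each request contributes at most a window's worth of KV entries, while wrapper 1 keeps the full seq_lens. A small numeric illustration of the decode-path clamp (window size and sequence lengths are arbitrary):

import torch

sliding_window_size = 4
seq_lens = torch.tensor([2, 5, 9])

# Decode path of the sliding-window branch: each request attends to at most
# window + 1 tokens (the window plus the token currently being decoded).
windowed_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
print(windowed_lens.tolist())  # [2, 5, 5]

# kv_start_idx marks where the windowed slice begins inside each request,
# so the gather kernel skips tokens that fell out of the window.
kv_start_idx = seq_lens - windowed_lens
print(kv_start_idx.tolist())   # [0, 0, 4]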

@@ -233,7 +226,12 @@ def update_flashinfer_indices(
         use_ragged,
     )

-    if model_runner.sliding_window_size is None:
-        updater.update_indices_no_sliding_window()
+    dispatch_reason = model_runner.attn_backend.dispatch_reason
+
+    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        updater._update_indices_sliding_window()
+    elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        updater._update_indices_cross_attention()
     else:
-        updater.update_indices_sliding_window()
+        assert model_runner.attn_backend.num_wrappers == 1
+        updater._update_indicess_single_wrapper()

python/sglang/srt/layers/radix_attention.py (view file @ 100f5b8b)

@@ -32,9 +32,10 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
-        sliding_window_size: int = -1,
         logit_cap: float = 0.0,
         v_head_dim: int = -1,
+        sliding_window_size: int = -1,
+        is_cross_attention: bool = False,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads

@@ -47,6 +48,7 @@ class RadixAttention(nn.Module):
         self.layer_id = layer_id
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
+        self.is_cross_attention = is_cross_attention

     def forward(self, q, k, v, forward_batch: ForwardBatch):
         if k is not None:
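
Note that RadixAttention still normalizes the window with sliding_window_size or -1, so call sites such as gemma2.py can pass None for global-attention layers and the backend sees the -1 sentinel it dispatches on. A tiny illustration of the idiom:

def normalize_window(sliding_window_size):
    # Mirrors `self.sliding_window_size = sliding_window_size or -1`:
    # None falls back to the -1 "no window" sentinel.
    return sliding_window_size or -1


print(normalize_window(None))  # -1   -> routed to the full-attention wrapper
print(normalize_window(4096))  # 4096 -> routed to the sliding-window wrapper
print(normalize_window(-1))    # -1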

python/sglang/srt/model_executor/model_runner.py (view file @ 100f5b8b)

@@ -231,6 +231,7 @@ class ModelRunner:
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
+        self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures, self.server_args.is_embedding
         )
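
has_cross_attention is detected the same way as the sliding-window hook: probe the model object and fall back to a default when the attribute is absent, so existing decoder-only models need no changes. A self-contained sketch of the pattern (the model classes below are dummies, not sglang models):

class DecoderOnlyModel:
    pass  # dummy: no cross-attention flag defined


class EncoderDecoderModel:
    has_cross_attention = True  # dummy model that opts in to cross attention


for model in (DecoderOnlyModel(), EncoderDecoderModel()):
    # Same pattern as ModelRunner: a missing attribute means False.
    has_cross_attention = getattr(model, "has_cross_attention", False)
    print(type(model).__name__, has_cross_attention)
# DecoderOnlyModel False
# EncoderDecoderModel True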

@@ -453,6 +454,10 @@ class ModelRunner:
                 "Window attention is not supported in the triton attention backend. "
                 "Please use `--attention-backend flashinfer`."
             )
+            assert not self.has_cross_attention, (
+                "Cross attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
+            )
             self.attn_backend = TritonAttnBackend(self)
         else:
             raise ValueError(

python/sglang/srt/models/gemma2.py (view file @ 100f5b8b)

@@ -163,12 +163,12 @@ class Gemma2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
-            logit_cap=self.config.attn_logit_softcapping,
             sliding_window_size=(
                 get_attention_sliding_window_size(config)
                 if use_sliding_window
                 else None
             ),
+            logit_cap=self.config.attn_logit_softcapping,
         )

     def forward(