Commit 100f5b8b (unverified)

Simplify flashinfer dispatch (#1552)

Authored Oct 01, 2024 by Liangsheng Yin; committed via GitHub on Oct 01, 2024.
Parent: 619bb6dd
Showing 5 changed files with 97 additions and 76 deletions (+97 −76).
  python/sglang/srt/layers/attention/flashinfer_backend.py  +21 −5
  python/sglang/srt/layers/attention/flashinfer_utils.py    +67 −69
  python/sglang/srt/layers/radix_attention.py               +3 −1
  python/sglang/srt/model_executor/model_runner.py          +5 −0
  python/sglang/srt/models/gemma2.py                        +1 −1
python/sglang/srt/layers/attention/flashinfer_backend.py  (view file @ 100f5b8b)

@@ -14,7 +14,10 @@ import torch.nn as nn
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_utils import update_flashinfer_indices
+from sglang.srt.layers.attention.flashinfer_utils import (
+    WrapperDispatch,
+    update_flashinfer_indices,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_hip

@@ -53,10 +56,19 @@ class FlashInferAttnBackend(AttentionBackend):
             device="cuda",
         )

+        assert not (
+            model_runner.sliding_window_size is not None
+            and model_runner.has_cross_attention
+        ), "Sliding window and cross attention are not supported together"
+
+        self.num_wrappers = 1
+        self.dispatch_reason = None
         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
-        else:
-            self.num_wrappers = 1
+            self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
+        elif model_runner.has_cross_attention:
+            self.num_wrappers = 2
+            self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION

         # NOTE: we do not use ragged attention when there are multiple wrappers
         self.prefill_wrapper_ragged = (

@@ -88,8 +100,12 @@ class FlashInferAttnBackend(AttentionBackend):
         if self.num_wrappers == 1:
             return 0

-        # TODO: make sure the idx is related to sliding window size
-        return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            return layer.is_cross_attention
+
+        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
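For reference, a minimal self-contained sketch of the dispatch pattern this file now uses. It is illustrative only (plain Python, no FlashInfer objects), and the `Layer` dataclass plus the helper names `choose_dispatch` and `wrapper_idx` are hypothetical stand-ins, not part of the real backend.

```python
# Illustrative sketch of the wrapper dispatch added above: the backend records
# *why* it needs two wrappers, and each attention layer is later mapped to
# wrapper 0 or 1 based on that reason.
from dataclasses import dataclass
from enum import Enum, auto


class WrapperDispatch(Enum):
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


@dataclass
class Layer:  # hypothetical stand-in for a RadixAttention layer
    sliding_window_size: int = -1
    is_cross_attention: bool = False


def choose_dispatch(sliding_window_size, has_cross_attention):
    # Mirrors the __init__ logic: default to one wrapper, escalate to two
    # when either feature is present (the two are mutually exclusive).
    num_wrappers, reason = 1, None
    if sliding_window_size is not None:
        num_wrappers, reason = 2, WrapperDispatch.SLIDING_WINDOW
    elif has_cross_attention:
        num_wrappers, reason = 2, WrapperDispatch.CROSS_ATTENTION
    return num_wrappers, reason


def wrapper_idx(layer, reason):
    # Mirrors the per-layer lookup: with sliding-window dispatch, global
    # layers (window == -1) go to wrapper 1 and windowed layers to wrapper 0;
    # with cross-attention dispatch the flag itself selects the wrapper.
    if reason is None:
        return 0
    if reason == WrapperDispatch.SLIDING_WINDOW:
        return int(layer.sliding_window_size == -1)
    if reason == WrapperDispatch.CROSS_ATTENTION:
        return int(layer.is_cross_attention)
    raise ValueError(f"Unknown dispatch reason: {reason}")


if __name__ == "__main__":
    n, reason = choose_dispatch(sliding_window_size=4096, has_cross_attention=False)
    print(n, wrapper_idx(Layer(sliding_window_size=4096), reason))  # 2 0
    print(wrapper_idx(Layer(sliding_window_size=-1), reason))       # 1
```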
python/sglang/srt/layers/attention/flashinfer_utils.py  (view file @ 100f5b8b)
+from enum import Enum, auto
+
 import torch
 import triton
 import triton.language as tl


+class WrapperDispatch(Enum):
+    SLIDING_WINDOW = auto()
+    CROSS_ATTENTION = auto()
+
+
 @triton.jit
 def create_flashinfer_kv_indices_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
@@ -80,67 +87,6 @@ class FlashinferUpdater:
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )

-    def _init_indices_no_sliding_window(self):
-        if self.use_ragged:
-            paged_kernel_lens = self.prefix_lens
-        else:
-            paged_kernel_lens = self.seq_lens
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_indices = torch.empty(
-            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-        )
-
-        create_flashinfer_kv_indices_triton[(self.batch_size,)](
-            self.model_runner.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            paged_kernel_lens,
-            self.kv_indptr,
-            None,
-            self.kv_indices,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
-        )
-
-    def _init_indices_sliding_window(self, wrapper_id):
-        if wrapper_id == 0:
-            # window attention use paged only
-            if self.forward_mode.is_decode():
-                paged_kernel_lens = torch.minimum(
-                    self.seq_lens,
-                    torch.tensor(self.model_runner.sliding_window_size + 1),
-                )
-            else:
-                paged_kernel_lens = torch.minimum(
-                    self.seq_lens,
-                    torch.tensor(self.model_runner.sliding_window_size)
-                    + self.seq_lens
-                    - self.prefix_lens,
-                )
-        else:
-            # full attention
-            paged_kernel_lens = self.seq_lens
-
-        kv_start_idx = self.seq_lens - paged_kernel_lens
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_indices = torch.empty(
-            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
-        )
-        create_flashinfer_kv_indices_triton[(self.batch_size,)](
-            self.model_runner.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            paged_kernel_lens,
-            self.kv_indptr,
-            kv_start_idx,
-            self.kv_indices,
-            self.model_runner.req_to_token_pool.req_to_token.size(1),
-        )
-
     def _update_decode_indices(self, decode_wrapper):
         assert not isinstance(decode_wrapper, list)
         decode_wrapper.end_forward()
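Both removed helpers (and the consolidated `_get_indices` added in the next hunk) build the same CSR-style layout. Below is a small CPU-only illustration of that layout under made-up sizes; the sequential loop stands in for the `create_flashinfer_kv_indices_triton` kernel, and the toy `req_to_token` pool is invented for the example.

```python
# kv_indptr is an exclusive prefix sum over per-request KV lengths;
# kv_indices is a flat buffer of token-pool slots gathered per request.
import torch

batch_size = 3
paged_kernel_lens = torch.tensor([2, 4, 3], dtype=torch.int32)  # KV tokens per request
req_pool_indices = torch.tensor([5, 0, 2])                      # row of each request in the pool
req_to_token = torch.arange(8 * 16, dtype=torch.int32).reshape(8, 16)  # toy [max_batch, max_context_len] pool

kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)

# What the Triton kernel does in parallel, written out sequentially:
# copy each request's token slots into its indptr segment.
for i in range(batch_size):
    start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
    kv_indices[start:end] = req_to_token[req_pool_indices[i], : int(paged_kernel_lens[i])]

print(kv_indptr.tolist())   # [0, 2, 6, 9]
print(kv_indices.tolist())  # request 5's slots, then request 0's, then request 2's
```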
@@ -189,8 +135,53 @@ class FlashinferUpdater:
             1,
             1,
         )

-    def update_indices_no_sliding_window(self):
-        self._init_indices_no_sliding_window()
+    def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
+        if dispatch_reason is None:
+            if self.use_ragged:
+                paged_kernel_lens = self.prefix_lens
+            else:
+                paged_kernel_lens = self.seq_lens
+            self.kv_start_idx = None
+        elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            if wrapper_id == 0:
+                # window attention use paged only
+                if self.forward_mode.is_decode():
+                    paged_kernel_lens = torch.minimum(
+                        self.seq_lens,
+                        torch.tensor(self.model_runner.sliding_window_size + 1),
+                    )
+                else:
+                    paged_kernel_lens = torch.minimum(
+                        self.seq_lens,
+                        torch.tensor(self.model_runner.sliding_window_size)
+                        + self.seq_lens
+                        - self.prefix_lens,
+                    )
+            else:
+                # full attention
+                paged_kernel_lens = self.seq_lens
+            self.kv_start_idx = self.seq_lens - paged_kernel_lens
+
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_indices = torch.empty(
+            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+        )
+
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.model_runner.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            paged_kernel_lens,
+            self.kv_indptr,
+            self.kv_start_idx,
+            self.kv_indices,
+            self.model_runner.req_to_token_pool.req_to_token.size(1),
+        )
+
+    def _update_indicess_single_wrapper(self):
+        self._get_indices()

         if self.forward_mode.is_decode():
             self._update_decode_indices(self.decode_wrappers[0])
@@ -200,11 +191,13 @@ class FlashinferUpdater:
             self.prefill_wrappers_paged[0],
         )

-    def update_indices_sliding_window(self):
-        assert self.use_ragged is False
+    def _update_indices_cross_attention(self):
+        pass
+
+    def _update_indices_sliding_window(self):
+        assert self.use_ragged is False
         for wrapper_id in range(2):
-            self._init_indices_sliding_window(wrapper_id)
+            self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
             if self.forward_mode.is_decode():
                 self._update_decode_indices(self.decode_wrappers[wrapper_id])
             else:
@@ -233,7 +226,12 @@ def update_flashinfer_indices(
...
@@ -233,7 +226,12 @@ def update_flashinfer_indices(
use_ragged
,
use_ragged
,
)
)
if
model_runner
.
sliding_window_size
is
None
:
dispatch_reason
=
model_runner
.
attn_backend
.
dispatch_reason
updater
.
update_indices_no_sliding_window
()
if
dispatch_reason
==
WrapperDispatch
.
SLIDING_WINDOW
:
updater
.
_update_indices_sliding_window
()
elif
dispatch_reason
==
WrapperDispatch
.
CROSS_ATTENTION
:
updater
.
_update_indices_cross_attention
()
else
:
else
:
updater
.
update_indices_sliding_window
()
assert
model_runner
.
attn_backend
.
num_wrappers
==
1
updater
.
_update_indicess_single_wrapper
()
python/sglang/srt/layers/radix_attention.py  (view file @ 100f5b8b)
@@ -32,9 +32,10 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
-        sliding_window_size: int = -1,
         logit_cap: float = 0.0,
         v_head_dim: int = -1,
+        sliding_window_size: int = -1,
+        is_cross_attention: bool = False,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
@@ -47,6 +48,7 @@ class RadixAttention(nn.Module):
         self.layer_id = layer_id
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1
+        self.is_cross_attention = is_cross_attention

     def forward(self, q, k, v, forward_batch: ForwardBatch):
         if k is not None:
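A toy sketch of just the constructor bookkeeping shown above (`ToyRadixAttention` is not the real class, which also sets up heads, scaling, and KV layout): `sliding_window_size or -1` normalizes both None and 0 to -1, and the new `is_cross_attention` flag is stored unchanged so the FlashInfer backend can route the layer to the right wrapper.

```python
# Normalization behavior of the two constructor arguments touched by this diff.
class ToyRadixAttention:
    def __init__(self, sliding_window_size: int = -1, is_cross_attention: bool = False):
        # None (e.g. a non-sliding-window Gemma2 layer) and 0 both become -1.
        self.sliding_window_size = sliding_window_size or -1
        self.is_cross_attention = is_cross_attention


print(ToyRadixAttention(sliding_window_size=None).sliding_window_size)   # -1
print(ToyRadixAttention(sliding_window_size=4096).sliding_window_size)   # 4096
print(ToyRadixAttention(is_cross_attention=True).is_cross_attention)     # True
```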
python/sglang/srt/model_executor/model_runner.py  (view file @ 100f5b8b)
@@ -231,6 +231,7 @@ class ModelRunner:
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
+        self.has_cross_attention = getattr(self.model, "has_cross_attention", False)
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures, self.server_args.is_embedding
         )
@@ -453,6 +454,10 @@ class ModelRunner:
                     "Window attention is not supported in the triton attention backend. "
                     "Please use `--attention-backend flashinfer`."
                 )
+            assert not self.has_cross_attention, (
+                "Cross attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
+            )
             self.attn_backend = TritonAttnBackend(self)
         else:
             raise ValueError(
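A minimal sketch of the capability probe added in the first hunk above, using a hypothetical `DummyModel` rather than a real sglang model: models that implement cross attention are expected to expose a `has_cross_attention` attribute, and everything else silently falls back to False.

```python
# getattr-based capability detection, mirroring the ModelRunner change.
class DummyModel:                 # hypothetical stand-in for a loaded model
    has_cross_attention = True


model = DummyModel()
has_cross_attention = getattr(model, "has_cross_attention", False)
print(has_cross_attention)        # True; a model without the attribute yields False
```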
python/sglang/srt/models/gemma2.py  (view file @ 100f5b8b)
@@ -163,12 +163,12 @@ class Gemma2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
-            logit_cap=self.config.attn_logit_softcapping,
             sliding_window_size=(
                 get_attention_sliding_window_size(config)
                 if use_sliding_window
                 else None
             ),
+            logit_cap=self.config.attn_logit_softcapping,
         )

     def forward(