sglang · Commits

Commit a7c3f74b (unverified), authored Apr 07, 2025 by Chunan Zeng, committed via GitHub on Apr 07, 2025

[FA3 Feature] Support multi modal Llama-3.2-11B-Vision-Instruct (#5103)
Parent: 5a144a8a

Showing 3 changed files with 113 additions and 9 deletions (+113 −9):
benchmark/mmmu/bench_sglang.py                                   +1   −1
python/sglang/srt/layers/attention/flashattention_backend.py  +111   −7
python/sglang/srt/model_executor/model_runner.py                 +1   −1
benchmark/mmmu/bench_sglang.py

```diff
@@ -86,8 +86,8 @@ def eval_mmmu(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    args = add_common_sglang_args_and_parse(parser)
     EvalArgs.add_cli_args(parser)
+    args = add_common_sglang_args_and_parse(parser)
     args = parser.parse_args()
 
     eval_mmmu(args)
```
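The benchmark change is purely about ordering: `EvalArgs.add_cli_args(parser)` now runs before the helper that parses the command line, so the evaluation flags are registered before any parsing happens. A minimal sketch of why that ordering matters, using a hypothetical stand-in helper rather than the real `add_common_sglang_args_and_parse`:

```python
# Minimal sketch: arguments must be registered before any parse happens,
# otherwise argparse never sees them. Helper name and flags are illustrative.
import argparse


def add_common_args_and_parse(parser: argparse.ArgumentParser) -> argparse.Namespace:
    # Stand-in for a helper that both registers shared flags and parses them.
    parser.add_argument("--port", type=int, default=30000)
    return parser.parse_args()


parser = argparse.ArgumentParser()
parser.add_argument("--eval-split", default="validation")  # register eval flags first
args = add_common_args_and_parse(parser)  # only then let the helper parse
print(args.eval_split, args.port)
```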
python/sglang/srt/layers/attention/flashattention_backend.py

```diff
@@ -42,6 +42,16 @@ class FlashAttentionMetadata:
     # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
 
+    # Encoder metadata
+    # Cumulative sequence lengths for encoder key
+    encoder_cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for encoder key
+    encoder_max_seq_len_k: int = 0
+    # Sequence lengths for the forward batch
+    encoder_lens_int32: torch.Tensor = None
+    # Page table for the encoder
+    encoder_page_table: torch.Tensor = None
+
     @dataclass
     class LocalAttentionMetadata:
         local_query_start_loc: torch.Tensor = None  # cu_seqlens_q for local attention
```
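As a worked example of what the new fields are meant to carry: for a single request whose encoder (vision) prefix is 4 tokens long, the values would look roughly like the sketch below. The numbers are illustrative only, not taken from the commit.

```python
# Illustrative contents of the new encoder metadata fields for one request
# with 4 encoder tokens (hypothetical values, single-request batch).
import torch

encoder_lens_int32 = torch.tensor([4], dtype=torch.int32)       # per-request encoder length
encoder_cu_seqlens_k = torch.tensor([0, 4], dtype=torch.int32)  # cumulative lengths, leading 0
encoder_max_seq_len_k = 4                                        # max encoder length in the batch
encoder_page_table = torch.tensor([[10, 11, 12, 13]])            # KV-cache slots of the encoder tokens
```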
```diff
@@ -435,6 +445,30 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 metadata.local_attn_metadata = local_metadata
 
+            # Encoder metadata for cross attention
+            if forward_batch.encoder_lens is not None:
+                assert (
+                    forward_batch.encoder_lens.numel() == 1
+                ), "Only encoder size 1 is supported for now"
+                metadata.encoder_lens_int32 = forward_batch.encoder_lens.to(torch.int32)
+                metadata.encoder_cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
+                    (1, 0),
+                )
+                metadata.encoder_max_seq_len_k = metadata.encoder_lens_int32.max().item()
+                metadata.encoder_page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.encoder_max_seq_len_k
+                ]
+
+                # Currently only support forward_batch.encoder_lens.numel() == 1
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices,
+                    metadata.encoder_max_seq_len_k : (
+                        metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                    ),
+                ]
+
         # Convert the page table to a strided format which is needed by FA3 API
         if self.page_size > 1:
             self.strided_indices = torch.arange(
```
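The layout assumption in this hunk is that a cross-attention request stores its encoder KV entries first in `req_to_token`, followed by the decoder entries, so the encoder page table is the first `encoder_max_seq_len_k` slots and the regular page table starts right after that offset. A small sketch of the split with toy tensors (not the real pool):

```python
# Toy illustration of splitting one req_to_token row into an encoder page
# table and a decoder page table, assuming encoder slots come first.
import torch

req_to_token = torch.arange(12).unsqueeze(0)  # one request, 12 KV-cache slots
encoder_len = 4                               # encoder (vision) tokens
decoder_len = 3                               # decoder tokens cached so far

encoder_page_table = req_to_token[:, :encoder_len]                     # slots 0..3
page_table = req_to_token[:, encoder_len : encoder_len + decoder_len]  # slots 4..6

# Cumulative encoder lengths with a leading zero, as the FA3 kernel expects.
encoder_lens = torch.tensor([encoder_len], dtype=torch.int32)
encoder_cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(encoder_lens, dim=0, dtype=torch.int32), (1, 0)
)
print(encoder_page_table, page_table, encoder_cu_seqlens_k)
```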
```diff
@@ -486,6 +520,7 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
+        causal = not layer.is_cross_attention
 
         # Check if we should use local attention
         use_local_attn = (
```
```diff
@@ -521,6 +556,12 @@ class FlashAttentionBackend(AttentionBackend):
             value_cache = value_cache.view(
                 -1, self.page_size, layer.tp_v_head_num, layer.head_dim
             )
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 k_cache=key_cache,
```
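For a cross-attention layer the query tokens should see every cached encoder key, so along with swapping in the encoder page table and lengths, the hunk disables the sliding window and (via `causal = not layer.is_cross_attention`) the causal mask. A toy illustration of that masking difference, unrelated to the FA3 kernel itself:

```python
# Toy illustration: self-attention keeps a causal mask, cross-attention to
# encoder keys attends to everything (window (-1, -1), causal=False).
import torch

q_len, kv_len = 3, 5
causal_mask = torch.ones(q_len, kv_len).tril(diagonal=kv_len - q_len).bool()
full_mask = torch.ones(q_len, kv_len, dtype=torch.bool)  # cross-attention sees all keys
print(causal_mask)
print(full_mask)
```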
```diff
@@ -531,7 +572,7 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                 max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
-                causal=True,
+                causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
                 k_descale=layer.k_scale,
```
```diff
@@ -614,6 +655,7 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
+        causal = not layer.is_cross_attention
 
         if not self.use_mla:
             # Do multi-head attention
```
```diff
@@ -627,17 +669,27 @@ class FlashAttentionBackend(AttentionBackend):
                 )
 
             q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+            else:
+                page_table = metadata.page_table
+                cache_seqlens = metadata.cache_seqlens_int32
+                cu_seqlens_k = metadata.cu_seqlens_k
+
             o = flash_attn_with_kvcache(
                 q=q_reshaped,
                 k_cache=key_cache,
                 v_cache=value_cache,
-                page_table=metadata.page_table,
-                cache_seqlens=metadata.cache_seqlens_int32,
+                page_table=page_table,
+                cache_seqlens=cache_seqlens,
                 cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                cu_seqlens_k_new=cu_seqlens_k,
                 max_seqlen_q=1,
                 softmax_scale=layer.scaling,
-                causal=True,
+                causal=causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
                 k_descale=layer.k_scale,
```
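The same four-way selection (page table, cache lengths, cumulative key lengths, window) now appears in both the extend and decode paths, so it could be factored into a small helper. The sketch below is a hypothetical refactoring, not code from the commit; the field names mirror the metadata fields above.

```python
# Hypothetical helper factoring out the per-layer KV selection added here.
from typing import Tuple
import torch


def select_kv_args(
    metadata, is_cross_attention: bool, window_size: Tuple[int, int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[int, int], bool]:
    """Pick the KV buffers and masking flags for one attention layer."""
    if is_cross_attention:
        # Cross-attention: read the encoder KV cache, no causal mask, no window.
        return (
            metadata.encoder_page_table,
            metadata.encoder_lens_int32,
            metadata.encoder_cu_seqlens_k,
            (-1, -1),
            False,
        )
    # Self-attention: regular decoder KV cache, causal, keep the layer's window.
    return (
        metadata.page_table,
        metadata.cache_seqlens_int32,
        metadata.cu_seqlens_k,
        window_size,
        True,
    )
```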
```diff
@@ -733,6 +785,21 @@ class FlashAttentionBackend(AttentionBackend):
             ),
         }
 
+        self.encoder_metadata = {
+            "encoder_page_table": torch.zeros(
+                max_bs,
+                self.max_context_len,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "encoder_lens_int32": torch.zeros(
+                max_bs, dtype=torch.int32, device=self.device
+            ),
+            "encoder_cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+        }
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
```
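These buffers exist because CUDA graph capture records fixed tensor addresses: the encoder metadata must live in preallocated, worst-case-sized tensors that are sliced at capture time and overwritten in place at replay time. A minimal sketch of that pattern with illustrative sizes (it falls back to CPU so it stays runnable anywhere):

```python
# Minimal sketch of the CUDA-graph-friendly buffer pattern: allocate once at
# worst-case size, take views at capture time, copy_() new values at replay.
import torch

max_bs, max_context_len = 8, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

encoder_metadata = {
    "encoder_page_table": torch.zeros(
        max_bs, max_context_len, dtype=torch.int32, device=device
    ),
    "encoder_lens_int32": torch.zeros(max_bs, dtype=torch.int32, device=device),
    "encoder_cu_seqlens_k": torch.zeros(max_bs + 1, dtype=torch.int32, device=device),
}

# Capture: the graph records kernels that read from these fixed addresses.
encoder_bs = 1
lens_view = encoder_metadata["encoder_lens_int32"][:encoder_bs]

# Replay: only the contents change, never the storage.
lens_view.copy_(torch.tensor([4], dtype=torch.int32, device=device))
print(lens_view)
```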
```diff
@@ -818,6 +885,19 @@ class FlashAttentionBackend(AttentionBackend):
             self.target_verify_metadata[bs] = metadata
 
+        if encoder_lens is not None:
+            encoder_bs = encoder_lens.numel()
+            metadata.encoder_lens_int32 = self.encoder_metadata["encoder_lens_int32"][
+                :encoder_bs
+            ]
+            metadata.encoder_cu_seqlens_k = self.encoder_metadata[
+                "encoder_cu_seqlens_k"
+            ][: (encoder_bs + 1)]
+
+            metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
+                req_pool_indices, :
+            ]
+
         self.forward_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
```
```diff
@@ -903,6 +983,30 @@ class FlashAttentionBackend(AttentionBackend):
             page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
             metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
 
+            if encoder_lens is not None:
+                # Only support encoder size 1 for now
+                metadata.encoder_max_seq_len_k = encoder_lens[0]
+                metadata.encoder_lens_int32.copy_(encoder_lens[:1])
+                metadata.encoder_cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.encoder_lens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                )
+
+                metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
+                    self.req_to_token[
+                        req_pool_indices, : metadata.encoder_max_seq_len_k
+                    ]
+                )
+
+                # Update the regular page table
+                page_table = self.req_to_token[
+                    req_pool_indices,
+                    metadata.encoder_max_seq_len_k : (
+                        metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
+                    ),
+                ]
+                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
         self.forward_metadata = metadata
 
     def get_cuda_graph_seq_len_fill_value(self):
```
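At replay time the hunk writes through `copy_()` so the graph-visible tensors keep their addresses, and it re-slices the regular page table to start after the encoder slots, mirroring the offset used during normal metadata initialization. A toy sketch of that offset-and-copy step (toy shapes, not the backend's real buffers):

```python
# Sketch of the replay-time rule followed here: update graph-visible tensors
# in place with copy_(), and offset the decoder page table past the encoder
# slots stored at the front of the req_to_token row.
import torch

max_seq_len_k = 6
page_table = torch.zeros(1, max_seq_len_k, dtype=torch.int32)  # captured by the graph

encoder_max_seq_len_k = 4
req_to_token_row = torch.arange(12, dtype=torch.int32).unsqueeze(0)

# Decoder slots start right after the encoder slots.
new_rows = req_to_token_row[
    :, encoder_max_seq_len_k : encoder_max_seq_len_k + 3  # 3 decoder tokens cached
]
page_table[:, : new_rows.shape[1]].copy_(new_rows)  # in-place update, same storage
print(page_table)
```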
```diff
@@ -956,7 +1060,7 @@ class FlashAttentionMultiStepBackend:
                 forward_batch.batch_size * self.topk,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
             )
```
```diff
@@ -973,7 +1077,7 @@ class FlashAttentionMultiStepBackend:
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
-                encoder_lens=None,
+                encoder_lens=forward_batch.encoder_lens,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
                 seq_lens_cpu=forward_batch.seq_lens_cpu,
```
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -886,7 +886,7 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             logger.warning(
-                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
+                "FlashAttention v3 Backend is in Beta. FP8 is not supported."
             )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
```