change / sglang · Commit 36f6fc50
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3045fb276352681f6b9075956e599dd8ef571872"
Unverified commit, authored Feb 10, 2025 by Yineng Zhang; committed by GitHub on Feb 10, 2025.
feat: enable ragged fa3 by default on hopper 12.4+ (#3442)
Parent: d8727275
Showing 1 changed file with 11 additions and 14 deletions.
python/sglang/srt/layers/attention/flashinfer_backend.py (+11, −14)
@@ -70,6 +70,8 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()
 
+        self.is_multimodal = model_runner.model_config.is_multimodal
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -130,12 +132,8 @@ class FlashInferAttnBackend(AttentionBackend):
             for _ in range(self.num_wrappers)
         ]
 
-        # Create wrappers
-        # NOTE: we do not use ragged attention when there are multiple wrappers
-        self.prefill_wrapper_ragged = (
-            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
-            if self.num_wrappers == 1
-            else None
-        )
+        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.workspace_buffer, "NHD"
+        )
 
         # Two wrappers: one for sliding window attention and one for full attention.
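The hunk above stops special-casing construction: the ragged wrapper used to exist only when num_wrappers == 1 and be None otherwise, and now it is always created, with the use/skip decision deferred to forward time (next hunk). A minimal sketch of that pattern, using stand-in classes rather than FlashInfer's real API:

# Sketch of the pattern this hunk adopts: construct the ragged wrapper
# unconditionally, decide per batch whether to use it.
# RaggedWrapper/PagedWrapper are stand-ins, not FlashInfer classes.

class RaggedWrapper:
    def run(self, batch):
        return f"ragged({batch})"

class PagedWrapper:
    def run(self, batch):
        return f"paged({batch})"

class Backend:
    def __init__(self, is_multimodal: bool):
        self.is_multimodal = is_multimodal
        # Before this commit the ragged wrapper was created only when
        # num_wrappers == 1; now it always exists, and the decision moves
        # to forward time (see the next hunk).
        self.prefill_wrapper_ragged = RaggedWrapper()
        self.prefill_wrapper_paged = PagedWrapper()

    def forward_extend(self, batch):
        use_ragged = not self.is_multimodal
        wrapper = (
            self.prefill_wrapper_ragged if use_ragged else self.prefill_wrapper_paged
        )
        return wrapper.run(batch)

print(Backend(is_multimodal=False).forward_extend("b0"))  # ragged(b0)
print(Backend(is_multimodal=True).forward_extend("b0"))   # paged(b0)

Constructing the wrapper unconditionally removes a None check at every call site; the per-batch flag alone picks the path.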
@@ -217,13 +215,12 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
-            # Some heuristics to check whether to use ragged forward
-            if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
-                use_ragged = True
-                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-            else:
+            if self.is_multimodal:
                 use_ragged = False
                 extend_no_prefix = False
+            else:
+                use_ragged = True
+                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
 
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
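The replacement heuristic is small enough to state as a pure function. A hedged sketch of the new logic and its edge cases (decide_ragged is a hypothetical helper, not sglang code):

from typing import List, Tuple

def decide_ragged(
    is_multimodal: bool, extend_prefix_lens_cpu: List[int]
) -> Tuple[bool, bool]:
    """Mirror of the hunk above: multimodal batches never take the ragged
    path; everyone else does, and extend_no_prefix is True only when no
    request in the batch carries a cached prefix."""
    if is_multimodal:
        return False, False
    # any() over the per-request prefix lengths: a single nonzero prefix
    # means the ragged kernel must still consult the paged KV cache.
    return True, not any(extend_prefix_lens_cpu)

assert decide_ragged(True, [0, 0]) == (False, False)
assert decide_ragged(False, [0, 0]) == (True, True)     # pure prefill, no prefix
assert decide_ragged(False, [0, 128]) == (True, False)  # some cached prefix

Note what disappeared: the old extend_num_tokens >= 4096 and num_wrappers == 1 gates are gone, so ragged prefill is now the default whenever the model is not multimodal.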
@@ -640,7 +637,6 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
 
-        wrapper.end_forward()
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
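This hunk, and the matching ones at lines 860 and 871 below, drop the wrapper.end_forward() call that used to precede re-planning. As I read the change, begin_forward() simply overwrites whatever plan state the wrapper held, so the explicit teardown is redundant. An illustrative stub, not FlashInfer's implementation:

# Stand-in wrapper showing why an explicit end_forward() becomes
# redundant when begin_forward() overwrites any previous plan state.
# This mimics the calling convention only.

class StubWrapper:
    def __init__(self):
        self._plan = None

    def begin_forward(self, kv_indptr, kv_indices):
        # Re-planning replaces the old state; nothing to tear down first.
        self._plan = (kv_indptr, kv_indices)

    def forward(self, q):
        assert self._plan is not None, "begin_forward() must run first"
        return f"attn({q}, plan={self._plan})"

w = StubWrapper()
w.begin_forward([0, 2], [7, 9])
print(w.forward("q0"))
w.begin_forward([0, 1], [3])  # no end_forward() needed in between
print(w.forward("q1"))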
@@ -651,6 +647,7 @@ class FlashInferIndicesUpdaterDecode:
             1,
             data_type=self.data_type,
             q_data_type=self.q_data_type,
+            non_blocking=True,
         )
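The non_blocking=True added here and in the prefill hunk further below plausibly lets the wrapper issue its host-to-device index copies asynchronously instead of stalling the CPU between batches. A standalone PyTorch sketch of what such a copy involves, independent of FlashInfer:

import torch

if torch.cuda.is_available():
    # Pinned (page-locked) host buffer: a prerequisite for a truly
    # asynchronous host-to-device transfer.
    host_indices = torch.arange(1024, dtype=torch.int32, pin_memory=True)
    dev_indices = host_indices.to("cuda", non_blocking=True)
    # The copy is ordered on the current stream; synchronize before
    # trusting the result from the host side.
    torch.cuda.synchronize()
    print(dev_indices.sum().item())
else:
    print("CUDA not available; non_blocking copies degrade to synchronous")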
@@ -860,7 +857,6 @@ class FlashInferIndicesUpdaterPrefill:
         # extend part
         if use_ragged:
-            wrapper_ragged.end_forward()
             wrapper_ragged.begin_forward(
                 qo_indptr,
                 qo_indptr,
@@ -871,7 +867,6 @@ class FlashInferIndicesUpdaterPrefill:
             )
 
         # cached part
-        wrapper_paged.end_forward()
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
@@ -883,6 +878,7 @@ class FlashInferIndicesUpdaterPrefill:
             1,
             q_data_type=self.q_data_type,
             custom_mask=custom_mask,
+            non_blocking=True,
         )
@@ -1125,6 +1121,7 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
+    **kwargs,
 ) -> None:
     """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
     batch_size = len(last_page_len)
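fast_decode_plan() gains a **kwargs catch-all, a common forward-compatibility move for a monkey-patched replacement: if a newer FlashInfer release passes keyword arguments this override does not know about, the call still succeeds. A sketch of the pattern with hypothetical names:

# Forward-compatibility via **kwargs: a patched replacement keeps working
# when the upstream caller starts passing new keyword arguments.
# fast_plan and window_left are hypothetical, for illustration only.

def fast_plan(indptr, *, sm_scale=None, **kwargs):
    # Accept-and-ignore avoids a TypeError when callers grow new options;
    # report the unexpected keys rather than crashing.
    if kwargs:
        print(f"ignoring extra plan() kwargs: {sorted(kwargs)}")
    return {"indptr": indptr, "sm_scale": sm_scale}

# An older caller:
print(fast_plan([0, 4]))
# A newer caller that grew an extra argument:
print(fast_plan([0, 4], sm_scale=0.125, window_left=-1))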