Commit f440baa1 (Unverified)
Authored Oct 18, 2025 by ykcombat; committed via GitHub on Oct 18, 2025

[Feature] Reuse flashinfer workspace for PD-Multiplexing. (#11540)
Parent: 2bc3fcd4
Showing 3 changed files with 13 additions and 2 deletions (+13 / -2)
python/sglang/srt/layers/attention/attention_registry.py   +3 -1
python/sglang/srt/layers/attention/flashinfer_backend.py   +9 -1
python/sglang/srt/model_executor/model_runner.py           +1 -0
python/sglang/srt/layers/attention/attention_registry.py

@@ -34,7 +34,9 @@ def create_flashinfer_backend(runner):
             or not runner.plan_stream_for_flashinfer
         ):
             runner.plan_stream_for_flashinfer = torch.cuda.Stream()
-        return FlashInferAttnBackend(runner)
+        return FlashInferAttnBackend(
+            runner, init_new_workspace=runner.init_new_workspace
+        )
     else:
         from sglang.srt.layers.attention.flashinfer_mla_backend import (
             FlashInferMLAAttnBackend,
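The registry now threads the runner-level flag into the backend constructor. As a hedged illustration (the call site that actually flips the flag is not part of this diff, and `runner` stands for the ModelRunner instance), a PD-multiplexing setup could obtain one backend that shares the process-global flashinfer workspace and a second one that owns a private buffer roughly like this:

# Hypothetical usage sketch, not code from this commit.
from sglang.srt.layers.attention.attention_registry import create_flashinfer_backend

shared_backend = create_flashinfer_backend(runner)    # runner.init_new_workspace is False by default

runner.init_new_workspace = True                      # ask the next backend for its own workspace
private_backend = create_flashinfer_backend(runner)   # allocates a fresh buffer in its constructor
runner.init_new_workspace = False                     # restore the default for later callers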
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -118,6 +118,7 @@ class FlashInferAttnBackend(AttentionBackend):
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
         kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        init_new_workspace: bool = False,
     ):
         super().__init__()
@@ -192,7 +193,14 @@ class FlashInferAttnBackend(AttentionBackend):
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer = global_workspace_buffer
+        if init_new_workspace:
+            self.workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        else:
+            self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
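This constructor change is the heart of the feature: by default the backend keeps sharing the lazily created global workspace, and only when init_new_workspace=True does it allocate its own buffer, so two backends multiplexed on the same GPU do not hand the same flashinfer workspace to concurrent plan/run calls. A minimal self-contained sketch of that selection logic, with _WORKSPACE_SIZE standing in for global_config.flashinfer_workspace_size and _global_workspace_buffer for the module-level cache:

import torch

_global_workspace_buffer = None          # stand-in for the module-level shared buffer
_WORKSPACE_SIZE = 128 * 1024 * 1024      # stand-in for global_config.flashinfer_workspace_size

def select_workspace(init_new_workspace: bool, device: str = "cuda") -> torch.Tensor:
    """Return the flashinfer workspace a backend instance should use."""
    global _global_workspace_buffer
    if init_new_workspace:
        # Private buffer: this backend's work cannot trample the workspace
        # of another backend running concurrently on the same device.
        return torch.empty(_WORKSPACE_SIZE, dtype=torch.uint8, device=device)
    if _global_workspace_buffer is None:
        # Lazily create the shared buffer once per process, mirroring the existing code.
        _global_workspace_buffer = torch.empty(
            _WORKSPACE_SIZE, dtype=torch.uint8, device=device
        )
    return _global_workspace_buffer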
python/sglang/srt/model_executor/model_runner.py

@@ -284,6 +284,7 @@ class ModelRunner:
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
         self.forward_pass_id = 0
+        self.init_new_workspace = False

         # Apply the rank zero filter to logger
         if server_args.show_time_cost:
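The runner only adds the flag with a conservative default: init_new_workspace = False, so every existing configuration keeps the single shared workspace and its memory footprint. A PD-multiplexing code path would presumably set the flag to True before the extra attention backend is constructed (as sketched above); that call site lives outside this diff.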