Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
a2cb5913
Unverified
Commit
a2cb5913
authored
Jun 02, 2025
by
Ke Bao
Committed by
GitHub
Jun 02, 2025
Browse files
Add draft extend CUDA graph for flashinfer backend (#6805)
parent
55444ed6
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
170 additions
and
3 deletions
+170
-3
python/sglang/srt/layers/attention/flashinfer_backend.py
python/sglang/srt/layers/attention/flashinfer_backend.py
+40
-0
python/sglang/srt/layers/attention/flashinfer_mla_backend.py
python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+32
-0
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
...g/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+3
-1
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+10
-1
test/srt/test_eagle_infer.py
test/srt/test_eagle_infer.py
+85
-1
No files found.
python/sglang/srt/layers/attention/flashinfer_backend.py
View file @
a2cb5913
...
@@ -358,6 +358,35 @@ class FlashInferAttnBackend(AttentionBackend):
...
@@ -358,6 +358,35 @@ class FlashInferAttnBackend(AttentionBackend):
)
)
self
.
prefill_cuda_graph_metadata
[
bs
]
=
prefill_wrappers
self
.
prefill_cuda_graph_metadata
[
bs
]
=
prefill_wrappers
self
.
forward_metadata
=
PrefillMetadata
(
prefill_wrappers
,
False
,
False
)
self
.
forward_metadata
=
PrefillMetadata
(
prefill_wrappers
,
False
,
False
)
elif
forward_mode
.
is_draft_extend
():
prefill_wrappers
=
[]
for
i
in
range
(
self
.
num_wrappers
):
prefill_wrappers
.
append
(
BatchPrefillWithPagedKVCacheWrapper
(
self
.
workspace_buffer
,
"NHD"
,
backend
=
"fa2"
,
use_cuda_graph
=
True
,
qo_indptr_buf
=
self
.
cuda_graph_qo_indptr
[
i
][:
bs
+
1
],
paged_kv_indptr_buf
=
self
.
kv_indptr
[
i
][:
bs
+
1
],
paged_kv_indices_buf
=
self
.
cuda_graph_kv_indices
[
i
],
paged_kv_last_page_len_buf
=
self
.
kv_last_page_len
[:
bs
],
)
)
seq_lens_sum
=
seq_lens
.
sum
().
item
()
self
.
indices_updater_prefill
.
update
(
req_pool_indices
,
seq_lens
,
seq_lens_sum
,
prefix_lens
=
None
,
prefill_wrappers
=
prefill_wrappers
,
use_ragged
=
False
,
encoder_lens
=
encoder_lens
,
spec_info
=
spec_info
,
)
self
.
prefill_cuda_graph_metadata
[
bs
]
=
prefill_wrappers
self
.
forward_metadata
=
PrefillMetadata
(
prefill_wrappers
,
False
,
False
)
else
:
else
:
raise
ValueError
(
f
"Invalid mode:
{
forward_mode
=
}
"
)
raise
ValueError
(
f
"Invalid mode:
{
forward_mode
=
}
"
)
...
@@ -392,6 +421,17 @@ class FlashInferAttnBackend(AttentionBackend):
...
@@ -392,6 +421,17 @@ class FlashInferAttnBackend(AttentionBackend):
encoder_lens
=
encoder_lens
[:
bs
]
if
encoder_lens
is
not
None
else
None
,
encoder_lens
=
encoder_lens
[:
bs
]
if
encoder_lens
is
not
None
else
None
,
spec_info
=
spec_info
,
spec_info
=
spec_info
,
)
)
elif
forward_mode
.
is_draft_extend
():
self
.
indices_updater_prefill
.
update
(
req_pool_indices
[:
bs
],
seq_lens
[:
bs
],
seq_lens_sum
,
prefix_lens
=
None
,
prefill_wrappers
=
self
.
prefill_cuda_graph_metadata
[
bs
],
use_ragged
=
False
,
encoder_lens
=
encoder_lens
[:
bs
]
if
encoder_lens
is
not
None
else
None
,
spec_info
=
spec_info
,
)
else
:
else
:
raise
ValueError
(
"Invalid forward mode"
)
raise
ValueError
(
"Invalid forward mode"
)
...
...
python/sglang/srt/layers/attention/flashinfer_mla_backend.py
View file @
a2cb5913
...
@@ -278,6 +278,28 @@ class FlashInferMLAAttnBackend(AttentionBackend):
...
@@ -278,6 +278,28 @@ class FlashInferMLAAttnBackend(AttentionBackend):
)
)
self
.
prefill_cuda_graph_metadata
[
bs
]
=
verify_wrapper
self
.
prefill_cuda_graph_metadata
[
bs
]
=
verify_wrapper
self
.
forward_metadata
=
PrefillMetadata
(
verify_wrapper
,
False
)
self
.
forward_metadata
=
PrefillMetadata
(
verify_wrapper
,
False
)
elif
forward_mode
.
is_draft_extend
():
draft_extend_wrapper
=
BatchMLAPagedAttentionWrapper
(
self
.
workspace_buffer
,
use_cuda_graph
=
True
,
qo_indptr
=
self
.
cuda_graph_qo_indptr
[:
bs
+
1
],
kv_indptr
=
self
.
cuda_graph_kv_indptr
[:
bs
+
1
],
kv_indices
=
self
.
cuda_graph_kv_indices
,
kv_len_arr
=
self
.
cuda_graph_kv_lens
[:
bs
],
backend
=
"auto"
,
)
seq_lens_sum
=
seq_lens
.
sum
().
item
()
self
.
indices_updater_prefill
.
update
(
req_pool_indices
,
seq_lens
,
seq_lens_sum
,
prefix_lens
=
None
,
prefill_wrapper_paged
=
draft_extend_wrapper
,
use_ragged
=
False
,
spec_info
=
spec_info
,
)
self
.
prefill_cuda_graph_metadata
[
bs
]
=
draft_extend_wrapper
self
.
forward_metadata
=
PrefillMetadata
(
draft_extend_wrapper
,
False
)
else
:
else
:
raise
ValueError
(
f
"Invalid mode:
{
forward_mode
=
}
"
)
raise
ValueError
(
f
"Invalid mode:
{
forward_mode
=
}
"
)
...
@@ -325,6 +347,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
...
@@ -325,6 +347,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
use_ragged
=
False
,
use_ragged
=
False
,
spec_info
=
spec_info
,
spec_info
=
spec_info
,
)
)
elif
forward_mode
.
is_draft_extend
():
self
.
indices_updater_prefill
.
update
(
req_pool_indices
[:
bs
],
seq_lens
[:
bs
],
seq_lens_sum
,
prefix_lens
=
None
,
prefill_wrapper_paged
=
self
.
prefill_cuda_graph_metadata
[
bs
],
use_ragged
=
False
,
spec_info
=
spec_info
,
)
else
:
else
:
raise
ValueError
(
f
"Invalid forward mode:
{
forward_mode
=
}
"
)
raise
ValueError
(
f
"Invalid forward mode:
{
forward_mode
=
}
"
)
...
...
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
View file @
a2cb5913
...
@@ -80,7 +80,9 @@ class EAGLEDraftExtendCudaGraphRunner:
...
@@ -80,7 +80,9 @@ class EAGLEDraftExtendCudaGraphRunner:
self
.
seq_lens
=
torch
.
ones
((
self
.
max_bs
,),
dtype
=
torch
.
int32
)
self
.
seq_lens
=
torch
.
ones
((
self
.
max_bs
,),
dtype
=
torch
.
int32
)
self
.
extend_seq_lens
=
torch
.
ones
((
self
.
max_bs
,),
dtype
=
torch
.
int32
)
self
.
extend_seq_lens
=
torch
.
ones
((
self
.
max_bs
,),
dtype
=
torch
.
int32
)
self
.
accept_length
=
torch
.
ones
((
self
.
max_bs
,),
dtype
=
torch
.
int32
)
self
.
accept_length
=
(
torch
.
ones
((
self
.
max_bs
,),
dtype
=
torch
.
int32
)
*
self
.
num_tokens_per_bs
)
# Capture
# Capture
try
:
try
:
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
a2cb5913
...
@@ -156,6 +156,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -156,6 +156,7 @@ class EAGLEWorker(TpModelWorker):
if
self
.
server_args
.
attention_backend
==
"flashinfer"
:
if
self
.
server_args
.
attention_backend
==
"flashinfer"
:
if
not
global_server_args_dict
[
"use_mla_backend"
]:
if
not
global_server_args_dict
[
"use_mla_backend"
]:
from
sglang.srt.layers.attention.flashinfer_backend
import
(
from
sglang.srt.layers.attention.flashinfer_backend
import
(
FlashInferAttnBackend
,
FlashInferMultiStepDraftBackend
,
FlashInferMultiStepDraftBackend
,
)
)
...
@@ -164,8 +165,13 @@ class EAGLEWorker(TpModelWorker):
...
@@ -164,8 +165,13 @@ class EAGLEWorker(TpModelWorker):
self
.
topk
,
self
.
topk
,
self
.
speculative_num_steps
,
self
.
speculative_num_steps
,
)
)
self
.
draft_extend_attn_backend
=
FlashInferAttnBackend
(
self
.
draft_model_runner
,
skip_prefill
=
False
,
)
else
:
else
:
from
sglang.srt.layers.attention.flashinfer_mla_backend
import
(
from
sglang.srt.layers.attention.flashinfer_mla_backend
import
(
FlashInferMLAAttnBackend
,
FlashInferMLAMultiStepDraftBackend
,
FlashInferMLAMultiStepDraftBackend
,
)
)
...
@@ -174,7 +180,10 @@ class EAGLEWorker(TpModelWorker):
...
@@ -174,7 +180,10 @@ class EAGLEWorker(TpModelWorker):
self
.
topk
,
self
.
topk
,
self
.
speculative_num_steps
,
self
.
speculative_num_steps
,
)
)
self
.
draft_extend_attn_backend
=
None
self
.
draft_extend_attn_backend
=
FlashInferMLAAttnBackend
(
self
.
draft_model_runner
,
skip_prefill
=
False
,
)
self
.
padded_static_len
=
self
.
speculative_num_steps
+
1
self
.
padded_static_len
=
self
.
speculative_num_steps
+
1
self
.
has_prefill_wrapper_verify
=
True
self
.
has_prefill_wrapper_verify
=
True
elif
self
.
server_args
.
attention_backend
==
"triton"
:
elif
self
.
server_args
.
attention_backend
==
"triton"
:
...
...
test/srt/test_eagle_infer.py
View file @
a2cb5913
...
@@ -19,6 +19,7 @@ from sglang.test.few_shot_gsm8k import run_eval
...
@@ -19,6 +19,7 @@ from sglang.test.few_shot_gsm8k import run_eval
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST_MLA
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
DEFAULT_URL_FOR_TEST
,
CustomTestCase
,
CustomTestCase
,
...
@@ -602,6 +603,7 @@ class TestEAGLEDraftExtend(CustomTestCase):
...
@@ -602,6 +603,7 @@ class TestEAGLEDraftExtend(CustomTestCase):
"fa3"
,
"fa3"
,
],
],
)
)
cls
.
accept_len_threshold
=
1.50
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
...
@@ -636,7 +638,89 @@ class TestEAGLEDraftExtend(CustomTestCase):
...
@@ -636,7 +638,89 @@ class TestEAGLEDraftExtend(CustomTestCase):
acc_length
=
1.0
acc_length
=
1.0
print
(
f
"
{
acc_length
=
}
"
)
print
(
f
"
{
acc_length
=
}
"
)
self
.
assertGreater
(
acc_length
,
1.50
)
self
.
assertGreater
(
acc_length
,
self
.
accept_len_threshold
)
class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend):
    """Run the inherited draft-extend tests with ``--attention-backend flashinfer``."""

    @classmethod
    def setUpClass(cls):
        """Launch the EAGLE test server once for the whole class.

        Stores the server URL, the launched process handle and the
        accept-length threshold as class attributes.
        """
        cls.base_url = DEFAULT_URL_FOR_TEST

        # Speculative-decoding and backend flags passed to the server launcher.
        launch_args = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
            "--speculative-num-steps",
            1,
            "--speculative-eagle-topk",
            1,
            "--speculative-num-draft-tokens",
            2,
            "--max-running-requests",
            4,
            "--attention-backend",
            "flashinfer",
        ]
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

        # Set only after a successful launch, matching the original ordering.
        cls.accept_len_threshold = 1.50
class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
    """Run the inherited draft-extend tests with ``--attention-backend triton``."""

    @classmethod
    def setUpClass(cls):
        """Launch the EAGLE test server once for the whole class.

        Stores the server URL, the launched process handle and the
        accept-length threshold as class attributes.
        """
        cls.base_url = DEFAULT_URL_FOR_TEST

        # Speculative-decoding and backend flags passed to the server launcher.
        launch_args = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
            "--speculative-num-steps",
            1,
            "--speculative-eagle-topk",
            1,
            "--speculative-num-draft-tokens",
            2,
            "--max-running-requests",
            4,
            "--attention-backend",
            "triton",
        ]
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

        # Set only after a successful launch, matching the original ordering.
        cls.accept_len_threshold = 1.50
class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend):
    """Run the inherited draft-extend tests on the MLA test model with flashinfer.

    Unlike the sibling subclasses, no ``--speculative-draft-model-path`` is
    passed here, and a higher accept-length threshold is set.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the EAGLE test server once for the whole class.

        Stores the server URL, the launched process handle and the
        accept-length threshold as class attributes.
        """
        cls.base_url = DEFAULT_URL_FOR_TEST

        # Speculative-decoding and backend flags passed to the server launcher.
        launch_args = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-num-steps",
            1,
            "--speculative-eagle-topk",
            1,
            "--speculative-num-draft-tokens",
            2,
            "--max-running-requests",
            4,
            "--attention-backend",
            "flashinfer",
        ]
        cls.process = popen_launch_server(
            DEFAULT_MODEL_NAME_FOR_TEST_MLA,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

        # Set only after a successful launch, matching the original ordering.
        cls.accept_len_threshold = 1.85
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment