Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
73d4a5f8
Unverified
Commit
73d4a5f8
authored
Oct 01, 2025
by
Liangsheng Yin
Committed by
GitHub
Oct 01, 2025
Browse files
Organize spec-related data structures (#10735)
parent
7fb551a7
Changes
32
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
866 additions
and
808 deletions
+866
-808
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+1
-1
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+10
-38
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
...n/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+1
-1
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
...g/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+2
-1
python/sglang/srt/speculative/eagle_info.py
python/sglang/srt/speculative/eagle_info.py
+183
-750
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+4
-2
python/sglang/srt/speculative/ngram_utils.py
python/sglang/srt/speculative/ngram_utils.py
+9
-4
python/sglang/srt/speculative/ngram_worker.py
python/sglang/srt/speculative/ngram_worker.py
+1
-5
python/sglang/srt/speculative/spec_info.py
python/sglang/srt/speculative/spec_info.py
+42
-0
python/sglang/srt/speculative/spec_utils.py
python/sglang/srt/speculative/spec_utils.py
+607
-0
python/sglang/srt/two_batch_overlap.py
python/sglang/srt/two_batch_overlap.py
+5
-4
test/srt/test_forward_split_prefill.py
test/srt/test_forward_split_prefill.py
+1
-2
No files found.
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
73d4a5f8
...
...
@@ -821,7 +821,7 @@ class CudaGraphRunner:
self
.
model_runner
.
spec_algorithm
.
is_eagle
()
or
self
.
model_runner
.
spec_algorithm
.
is_standalone
()
):
from
sglang.srt.speculative.eagle_
utils
import
EagleVerifyInput
from
sglang.srt.speculative.eagle_
info
import
EagleVerifyInput
if
self
.
model_runner
.
is_draft_worker
:
raise
RuntimeError
(
"This should not happen."
)
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
73d4a5f8
...
...
@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_size
,
set_dp_buffer_len
,
)
from
sglang.srt.layers.rotary_embedding
import
MRotaryEmbedding
from
sglang.srt.utils
import
(
flatten_nested_list
,
get_compiler_backend
,
is_npu
,
support_triton
,
)
from
sglang.srt.utils
import
get_compiler_backend
,
is_npu
,
support_triton
if
TYPE_CHECKING
:
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
...
...
@@ -60,8 +54,7 @@ if TYPE_CHECKING:
from
sglang.srt.mem_cache.memory_pool
import
KVCache
,
ReqToTokenPool
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.speculative.spec_info
import
SpecInput
,
SpeculativeAlgorithm
_is_npu
=
is_npu
()
...
...
@@ -293,7 +286,7 @@ class ForwardBatch:
global_forward_mode
:
Optional
[
ForwardMode
]
=
None
# Speculative decoding
spec_info
:
Optional
[
Union
[
EagleVerifyInput
,
EagleDraft
Input
]
]
=
None
spec_info
:
Optional
[
Spec
Input
]
=
None
spec_algorithm
:
SpeculativeAlgorithm
=
None
capture_hidden_mode
:
CaptureHiddenMode
=
None
...
...
@@ -364,33 +357,14 @@ class ForwardBatch:
# For MLP sync
if
batch
.
global_num_tokens
is
not
None
:
from
sglang.srt.speculative.eagle_utils
import
(
EagleDraftInput
,
EagleVerifyInput
,
)
assert
batch
.
global_num_tokens_for_logprob
is
not
None
# process global_num_tokens and global_num_tokens_for_logprob
if
batch
.
spec_info
is
not
None
:
if
isinstance
(
batch
.
spec_info
,
EagleDraftInput
):
global_num_tokens
=
[
x
*
batch
.
spec_info
.
num_tokens_per_batch
for
x
in
batch
.
global_num_tokens
]
global_num_tokens_for_logprob
=
[
x
*
batch
.
spec_info
.
num_tokens_for_logprob_per_batch
for
x
in
batch
.
global_num_tokens_for_logprob
]
else
:
assert
isinstance
(
batch
.
spec_info
,
EagleVerifyInput
)
global_num_tokens
=
[
x
*
batch
.
spec_info
.
draft_token_num
for
x
in
batch
.
global_num_tokens
]
global_num_tokens_for_logprob
=
[
x
*
batch
.
spec_info
.
draft_token_num
for
x
in
batch
.
global_num_tokens_for_logprob
]
spec_info
:
SpecInput
=
batch
.
spec_info
global_num_tokens
,
global_num_tokens_for_logprob
=
(
spec_info
.
get_spec_adjusted_global_num_tokens
(
batch
)
)
else
:
global_num_tokens
=
batch
.
global_num_tokens
global_num_tokens_for_logprob
=
batch
.
global_num_tokens_for_logprob
...
...
@@ -669,9 +643,6 @@ class ForwardBatch:
)
def
prepare_mlp_sync_batch
(
self
,
model_runner
:
ModelRunner
):
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
assert
self
.
global_num_tokens_cpu
is
not
None
assert
self
.
global_num_tokens_for_logprob_cpu
is
not
None
...
...
@@ -768,7 +739,8 @@ class ForwardBatch:
if
self
.
extend_seq_lens
is
not
None
:
self
.
extend_seq_lens
=
self
.
_pad_tensor_to_size
(
self
.
extend_seq_lens
,
bs
)
if
self
.
spec_info
is
not
None
and
isinstance
(
self
.
spec_info
,
EagleDraftInput
):
if
self
.
spec_info
is
not
None
and
self
.
spec_info
.
is_draft_input
():
# FIXME(lsyin): remove this isinstance logic
spec_info
=
self
.
spec_info
self
.
output_cache_loc_backup
=
self
.
out_cache_loc
self
.
hidden_states_backup
=
spec_info
.
hidden_states
...
...
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
View file @
73d4a5f8
...
...
@@ -20,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch
,
ForwardMode
,
)
from
sglang.srt.speculative.eagle_
utils
import
EagleDraftInput
from
sglang.srt.speculative.eagle_
info
import
EagleDraftInput
from
sglang.srt.utils
import
(
require_attn_tp_gather
,
require_gathered_buffer
,
...
...
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
View file @
73d4a5f8
...
...
@@ -21,7 +21,8 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch
,
ForwardMode
,
)
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
,
fast_topk
from
sglang.srt.speculative.eagle_info
import
EagleDraftInput
from
sglang.srt.speculative.spec_utils
import
fast_topk
from
sglang.srt.utils
import
(
require_attn_tp_gather
,
require_gathered_buffer
,
...
...
python/sglang/srt/speculative/eagle_
utils
.py
→
python/sglang/srt/speculative/eagle_
info
.py
View file @
73d4a5f8
This diff is collapsed.
Click to expand it.
python/sglang/srt/speculative/eagle_worker.py
View file @
73d4a5f8
...
...
@@ -34,16 +34,18 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
from
sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner
import
(
EAGLEDraftExtendCudaGraphRunner
,
)
from
sglang.srt.speculative.eagle_
utils
import
(
from
sglang.srt.speculative.eagle_
info
import
(
EagleDraftInput
,
EagleVerifyInput
,
EagleVerifyOutput
,
)
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.speculative.spec_utils
import
(
assign_draft_cache_locs
,
fast_topk
,
generate_token_bitmask
,
select_top_k_tokens
,
)
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.utils
import
(
empty_context
,
get_available_gpu_memory
,
...
...
python/sglang/srt/speculative/ngram_utils.py
View file @
73d4a5f8
...
...
@@ -2,7 +2,7 @@ from __future__ import annotations
import
copy
import
logging
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
import
torch
import
triton
...
...
@@ -13,6 +13,7 @@ from dataclasses import dataclass
import
torch.nn.functional
as
F
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.sampler
import
apply_custom_logit_processor
from
sglang.srt.managers.schedule_batch
import
(
...
...
@@ -21,10 +22,10 @@ from sglang.srt.managers.schedule_batch import (
global_server_args_dict
,
)
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.speculative.eagle_utils
import
(
from
sglang.srt.speculative.spec_info
import
SpecInput
,
SpecInputType
from
sglang.srt.speculative.spec_utils
import
(
TREE_SPEC_KERNEL_AVAILABLE
,
assign_req_to_token_pool
,
create_flashinfer_kv_indices_triton
,
get_src_tgt_cache_loc
,
get_target_cache_loc
,
)
...
...
@@ -42,7 +43,7 @@ elif is_hip():
@
dataclass
class
NgramVerifyInput
:
class
NgramVerifyInput
(
SpecInput
)
:
def
__init__
(
self
,
draft_token
:
torch
.
Tensor
,
...
...
@@ -53,6 +54,7 @@ class NgramVerifyInput:
retrive_next_sibling
:
torch
.
Tensor
,
draft_token_num
:
int
,
):
super
().
__init__
(
SpecInputType
.
NGRAM_VERIFY
)
self
.
draft_token
=
draft_token
self
.
custom_mask
=
tree_mask
self
.
positions
=
positions
...
...
@@ -62,6 +64,9 @@ class NgramVerifyInput:
self
.
draft_token_num
=
draft_token_num
self
.
device
=
self
.
custom_mask
.
device
def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
    """Return the per-request multipliers for the global token counts.

    The first element scales the token counts and the second scales the
    token-for-logprob counts; for n-gram verification both are
    ``draft_token_num``.
    """
    draft_tokens = self.draft_token_num
    return draft_tokens, draft_tokens
def
prepare_for_verify
(
self
,
batch
:
ScheduleBatch
,
page_size
:
int
):
if
batch
.
forward_mode
.
is_idle
():
return
...
...
python/sglang/srt/speculative/ngram_worker.py
View file @
73d4a5f8
import
logging
import
os
import
threading
import
time
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Union
from
typing
import
List
,
Optional
import
numpy
as
np
import
torch
...
...
@@ -15,7 +12,6 @@ from sglang.srt.server_args import ServerArgs
from
sglang.srt.speculative.cpp_ngram.ngram_cache
import
NgramCache
from
sglang.srt.speculative.ngram_utils
import
NgramVerifyInput
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.utils
import
broadcast_pyobj
logger
=
logging
.
getLogger
(
__name__
)
...
...
python/sglang/srt/speculative/spec_info.py
View file @
73d4a5f8
from
abc
import
ABC
,
abstractmethod
from
enum
import
IntEnum
,
auto
from
typing
import
List
,
Tuple
from
sglang.srt.managers.schedule_batch
import
ModelWorkerBatch
class
SpeculativeAlgorithm
(
IntEnum
):
...
...
@@ -35,3 +39,41 @@ class SpeculativeAlgorithm(IntEnum):
if
name
is
not
None
:
name
=
name
.
upper
()
return
name_map
[
name
]
class SpecInputType(IntEnum):
    """Tag identifying which concrete kind of ``SpecInput`` an object is.

    NOTE: introduced to distinguish the SpecInput types of multiple
    algorithms when asserting in attention backends. If all algorithms can
    share the same data structure for draft_input and verify_input,
    consider simplifying it.
    """

    # Draft-stage input of the EAGLE algorithm.
    EAGLE_DRAFT = auto()
    # Verify-stage input of the EAGLE algorithm.
    EAGLE_VERIFY = auto()
    # Verify-stage input of the n-gram algorithm.
    NGRAM_VERIFY = auto()
class SpecInput(ABC):
    """Abstract base for the speculative-decoding inputs attached to a batch.

    Each instance carries a ``SpecInputType`` tag, and subclasses supply the
    coefficients used to scale the batch's global token counts.
    """

    def __init__(self, spec_input_type: SpecInputType):
        """Record which kind of speculative input this object is."""
        self.spec_input_type = spec_input_type

    def is_draft_input(self) -> bool:
        """Return True when this input is tagged ``EAGLE_DRAFT``."""
        # FIXME: remove this function which is only used for assertion
        # or use another variable name like `draft_input` to substitute `spec_info`
        return self.spec_input_type == SpecInputType.EAGLE_DRAFT

    def is_verify_input(self) -> bool:
        """Return True when this input is tagged as one of the verify types."""
        verify_types = {
            SpecInputType.EAGLE_VERIFY,
            SpecInputType.NGRAM_VERIFY,
        }
        return self.spec_input_type in verify_types

    @abstractmethod
    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        """Return ``(token_coefficient, logprob_token_coefficient)``.

        The pair is consumed by ``get_spec_adjusted_global_num_tokens``: the
        first value multiplies every entry of ``global_num_tokens`` and the
        second every entry of ``global_num_tokens_for_logprob``.
        """

    def get_spec_adjusted_global_num_tokens(
        self, forward_batch: ModelWorkerBatch
    ) -> Tuple[List[int], List[int]]:
        """Scale the batch's global token counts by this input's coefficients.

        Args:
            forward_batch: object exposing ``global_num_tokens`` and
                ``global_num_tokens_for_logprob`` as iterables of counts.

        Returns:
            Two new lists: the scaled token counts and the scaled
            token-for-logprob counts. The batch itself is not modified.
        """
        token_coef, logprob_coef = self.get_spec_adjust_token_coefficient()
        scaled_tokens = [count * token_coef for count in forward_batch.global_num_tokens]
        scaled_logprob_tokens = [
            count * logprob_coef
            for count in forward_batch.global_num_tokens_for_logprob
        ]
        return scaled_tokens, scaled_logprob_tokens
python/sglang/srt/speculative/spec_utils.py
0 → 100644
View file @
73d4a5f8
This diff is collapsed.
Click to expand it.
python/sglang/srt/two_batch_overlap.py
View file @
73d4a5f8
...
...
@@ -30,7 +30,8 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from
sglang.srt.operations
import
execute_operations
,
execute_overlapped_operations
from
sglang.srt.operations_strategy
import
OperationsStrategy
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.speculative.eagle_info
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.speculative.spec_info
import
SpecInput
from
sglang.srt.utils
import
BumpAllocator
,
empty_context
,
get_bool_env_var
,
is_hip
if
TYPE_CHECKING
:
...
...
@@ -48,7 +49,7 @@ logger = logging.getLogger(__name__)
def
get_token_num_per_seq
(
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
=
None
,
spec_info
:
Optional
[
Spec
Input
]
=
None
,
):
if
forward_mode
.
is_target_verify
():
return
spec_info
.
draft_token_num
...
...
@@ -273,7 +274,7 @@ def compute_split_token_index(
def
compute_split_indices_for_cuda_graph_replay
(
forward_mode
:
ForwardMode
,
cuda_graph_num_tokens
:
int
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
forward_mode_for_tbo_split
=
(
forward_mode
if
forward_mode
!=
ForwardMode
.
IDLE
else
ForwardMode
.
DECODE
...
...
@@ -333,7 +334,7 @@ class TboCudaGraphRunnerPlugin:
forward_mode
:
ForwardMode
,
bs
:
int
,
num_token_non_padded
:
int
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
token_num_per_seq
=
get_token_num_per_seq
(
forward_mode
=
forward_mode
,
spec_info
=
spec_info
...
...
test/srt/test_forward_split_prefill.py
View file @
73d4a5f8
...
...
@@ -7,7 +7,6 @@ or
python3 test_forward_split_prefill.py
"""
import
time
import
unittest
import
numpy
as
np
...
...
@@ -16,7 +15,7 @@ import torch
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.srt.managers.schedule_batch
import
Req
,
ScheduleBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment