change / sglang · Commits

Commit 73d4a5f8 (unverified), authored Oct 01, 2025 by Liangsheng Yin, committed by GitHub Oct 01, 2025
Parent: 7fb551a7

Organize spec-related data structures (#10735)

Changes: 32 files in the commit; this page shows 20 changed files with 95 additions and 117 deletions (+95 −117).
Files shown on this page:

  python/sglang/srt/constrained/outlines_jump_forward.py             +1   -1
  python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py     +1   -1
  python/sglang/srt/layers/attention/aiter_backend.py                 +9   -14
  python/sglang/srt/layers/attention/ascend_backend.py                +4   -4
  python/sglang/srt/layers/attention/base_attn_backend.py             +3   -3
  python/sglang/srt/layers/attention/cutlass_mla_backend.py           +3   -3
  python/sglang/srt/layers/attention/flashattention_backend.py        +5   -6
  python/sglang/srt/layers/attention/flashinfer_backend.py            +17  -19
  python/sglang/srt/layers/attention/flashinfer_mla_backend.py        +11  -15
  python/sglang/srt/layers/attention/flashmla_backend.py              +3   -3
  python/sglang/srt/layers/attention/hybrid_attn_backend.py           +3   -3
  python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py    +6   -6
  python/sglang/srt/layers/attention/tbo_backend.py                   +6   -6
  python/sglang/srt/layers/attention/triton_backend.py                +4   -4
  python/sglang/srt/layers/attention/trtllm_mha_backend.py            +5   -7
  python/sglang/srt/layers/attention/trtllm_mla_backend.py            +3   -3
  python/sglang/srt/layers/attention/wave_backend.py                  +4   -4
  python/sglang/srt/layers/moe/fused_moe_triton/layer.py              +0   -5
  python/sglang/srt/managers/schedule_batch.py                        +6   -9
  python/sglang/srt/model_executor/cpu_graph_runner.py                +1   -1
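The common theme across these files: call sites that previously depended on the `SpecInfo` alias or on per-algorithm unions (`EagleDraftInput`, `EagleVerifyInput`, `NgramVerifyInput`) now depend on a single `SpecInput` type from `sglang.srt.speculative.spec_info`, and concrete `isinstance` checks become interface queries such as `is_draft_input()`. The definition of `SpecInput` is not part of this page; the sketch below only guesses at the minimal shape the hunks rely on (the enum, constructor, and exact method signature are assumptions, not the real source):

    # Hypothetical sketch only; the real class lives in
    # python/sglang/srt/speculative/spec_info.py and may differ.
    from abc import ABC, abstractmethod
    from enum import Enum, auto
    from typing import Tuple

    import torch


    class SpecInputType(Enum):
        EAGLE_DRAFT = auto()   # draft stage of EAGLE speculative decoding
        EAGLE_VERIFY = auto()  # verify stage of EAGLE speculative decoding
        NGRAM_VERIFY = auto()  # verify stage of n-gram speculative decoding


    class SpecInput(ABC):
        """Common interface for every speculative-decoding input payload."""

        def __init__(self, spec_input_type: SpecInputType):
            self.spec_input_type = spec_input_type

        def is_draft_input(self) -> bool:
            # Replaces isinstance(spec_info, EagleDraftInput) at the call sites.
            return self.spec_input_type == SpecInputType.EAGLE_DRAFT

        @abstractmethod
        def generate_attn_arg_prefill(
            self,
            req_pool_indices: torch.Tensor,
            # ...remaining arguments are elided in the hunks below...
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            """Return (kv_indices, kv_indptr, qo_indptr, custom_mask)."""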
python/sglang/srt/constrained/outlines_jump_forward.py

@@ -37,7 +37,7 @@ except ImportError:
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

-# Env var was set in sglang.srt.server_args.ServerArgs.__post__init__
+# Env var was set in sglang.srt.server_args.ServerArgs.__post_init__
 DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")

 logger = logging.getLogger(__name__)
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

@@ -157,7 +157,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
         hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)

         # local import to avoid circular import
-        from sglang.srt.speculative.eagle_utils import EagleDraftInput
+        from sglang.srt.speculative.eagle_info import EagleDraftInput

         spec_info = EagleDraftInput(
             topk_p=topk_p,
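The old `eagle_utils` module has been split: `EagleDraftInput`/`EagleVerifyInput` now come from `eagle_info`, and shared helpers such as `generate_draft_decode_kv_indices` from `spec_utils` (see the backend hunks below). The import stays local to the method; a generic sketch of that pattern, with the wrapper function and any arguments beyond `topk_p` assumed rather than taken from the source:

    def _build_draft_input(topk_p, **kwargs):
        # Local import to avoid a circular import: the speculative package
        # itself imports scheduler-side modules, so importing it at the top
        # of this file would create a cycle at module load time.
        from sglang.srt.speculative.eagle_info import EagleDraftInput

        return EagleDraftInput(topk_p=topk_p, **kwargs)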
python/sglang/srt/layers/attention/aiter_backend.py

@@ -4,18 +4,13 @@ from __future__ import annotations
 end to end attention solution with aiter kernels
 """

-import math
-import os
-from dataclasses import dataclass
-from enum import Enum, auto
-from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Optional

 import torch
 import triton
 import triton.language as tl

 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (

@@ -27,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput

 try:
     from aiter import (

@@ -374,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None

@@ -509,7 +504,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():

@@ -888,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

@@ -900,7 +895,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         kv_start_idx = None

@@ -984,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

@@ -997,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
    ):
         bs = len(req_pool_indices)

@@ -1054,7 +1049,7 @@ class AiterMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
python/sglang/srt/layers/attention/ascend_backend.py

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, List, Optional
 import torch
 import torch_npu
-from torch.nn.functional import scaled_dot_product_attention

 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend

@@ -13,7 +12,8 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import get_bool_env_var

 if TYPE_CHECKING:

@@ -127,7 +127,7 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         metadata = ForwardMetadata()

@@ -147,7 +147,7 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         metadata = self.graph_metadata[bs]
python/sglang/srt/layers/attention/base_attn_backend.py

@@ -8,7 +8,7 @@ import torch
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInput


 class AttentionBackend(ABC):

@@ -31,7 +31,7 @@ class AttentionBackend(ABC):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

@@ -44,7 +44,7 @@ class AttentionBackend(ABC):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Init the metadata for a forward pass for replaying a cuda graph."""
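The abstract base only uses `SpecInput` in annotations, so the import lives under `TYPE_CHECKING` and costs nothing at runtime. A self-contained sketch of why that works, with the parameter list abridged and partly assumed from the hunks above:

    from __future__ import annotations  # annotations are not evaluated at runtime

    from abc import ABC
    from typing import TYPE_CHECKING, Optional

    import torch

    if TYPE_CHECKING:
        # Seen only by type checkers; at runtime the speculative package is
        # never imported here, so no import cycle or load-time cost occurs.
        from sglang.srt.speculative.spec_info import SpecInput


    class AttentionBackend(ABC):
        def init_forward_metadata_capture_cuda_graph(
            self,
            bs: int,
            req_pool_indices: torch.Tensor,
            seq_lens: torch.Tensor,
            encoder_lens: Optional[torch.Tensor],
            forward_mode: "ForwardMode",  # also TYPE_CHECKING-only in the real file
            spec_info: Optional[SpecInput],
        ):
            """Init the metadata for a forward pass for capturing a cuda graph."""
            raise NotImplementedError()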
python/sglang/srt/layers/attention/cutlass_mla_backend.py

@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput

 _is_cuda = is_cuda()
 if _is_cuda:

@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             if spec_info is None:

@@ -190,7 +190,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -11,9 +11,8 @@ import triton.language as tl
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention

@@ -1487,7 +1486,7 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()

@@ -1722,7 +1721,7 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
         out_cache_loc: Optional[torch.Tensor] = None,
     ):

@@ -2340,7 +2339,7 @@ class FlashAttentionMultiStepBackend:
         forward_batch: ForwardBatch,
     ):
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()

         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(

@@ -2357,7 +2356,7 @@ class FlashAttentionMultiStepBackend:
         self, forward_batch: ForwardBatch, bs: int
     ):
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()

         for i in range(self.speculative_num_steps - 1):
             # TODO: incrementally update the metadata for the later steps,
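The last two hunks show the recurring behavioral change: backends stop asserting on the concrete `EagleDraftInput` class and ask the payload itself. A self-contained illustration with dummy stand-ins (not the real sglang types) of why this decouples the backends from the Eagle module:

    class _DraftLike:
        def is_draft_input(self) -> bool:
            return True

    class _VerifyLike:
        def is_draft_input(self) -> bool:
            return False

    def check_draft_stage(spec_info) -> None:
        # The backend no longer imports EagleDraftInput; any object honoring
        # the SpecInput interface works, so a new speculative algorithm needs
        # no edits inside the attention backends.
        assert spec_info is not None
        assert spec_info.is_draft_input()

    check_draft_stage(_DraftLike())      # passes
    try:
        check_draft_stage(_VerifyLike())
    except AssertionError:
        pass                             # verify-stage payloads are rejected, as before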
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -28,8 +28,8 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     get_int_env_var,
     is_flashinfer_available,

@@ -344,7 +344,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []

@@ -451,7 +451,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():

@@ -669,7 +669,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):

@@ -684,7 +684,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):

@@ -710,7 +710,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):

@@ -760,7 +760,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
         disable_split_kv: Optional[bool] = None,
     ):

@@ -794,7 +794,7 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,

@@ -905,7 +905,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.

@@ -921,7 +921,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
     ):
         if use_ragged:

@@ -959,7 +959,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):

@@ -1006,7 +1006,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):

@@ -1049,7 +1049,7 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
+        spec_info: Optional[SpecInput],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
     ):

@@ -1077,9 +1077,7 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(
-                spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
-            )
+            assert isinstance(spec_info, SpecInput)
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,

@@ -1138,7 +1136,7 @@ class FlashInferMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

         self.topk = topk
         self.speculative_num_steps = speculative_num_steps

@@ -1202,7 +1200,7 @@ class FlashInferMultiStepDraftBackend:
         )
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()

         # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
         indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
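In the `-1077` hunk the three-way `isinstance` collapses to one guard followed by a polymorphic call. A sketch of the resulting dispatch; `SpecInput` is the real import from this diff, but the argument list of `generate_attn_arg_prefill` beyond `req_pool_indices` is elided in the hunk and left elided here:

    from sglang.srt.speculative.spec_info import SpecInput

    def build_prefill_indices(spec_info, req_pool_indices, bs, kv_indptr, qo_indptr):
        if spec_info is None:
            # Dense (non-speculative) path; details elided in the hunk.
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
            kv_indices = ...  # computed from kv_indptr in the real code
        else:
            # One guard now covers Eagle draft, Eagle verify, and n-gram verify.
            assert isinstance(spec_info, SpecInput)
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    # ...remaining arguments as at the real call site...
                )
            )
        return kv_indices, kv_indptr, qo_indptr, custom_mask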
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -30,7 +30,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     is_flashinfer_available,
     is_sm100_supported,

@@ -40,7 +40,7 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput

 if is_flashinfer_available():
     from flashinfer import (

@@ -361,7 +361,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(

@@ -441,7 +441,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():

@@ -663,7 +663,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrapper: BatchMLAPagedAttentionWrapper,
         init_metadata_replay: bool = False,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
         **fast_decode_kwargs,
     ):
         decode_wrapper = decode_wrapper or self.decode_wrapper

@@ -688,7 +688,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         q_indptr: torch.Tensor,
         kv_indptr: torch.Tensor,
         init_metadata_replay: bool = False,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
         **fast_decode_kwargs,
     ):
         bs = len(req_pool_indices)

@@ -776,7 +776,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens

@@ -811,7 +811,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
+        spec_info: Optional[SpecInput] = None,
     ):
         bs = len(seq_lens)
         sm_scale = self.scaling

@@ -838,9 +838,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(spec_info, EagleDraftInput) or isinstance(
-                spec_info, EagleVerifyInput
-            )
+            assert isinstance(spec_info, SpecInput)
             # TODO: Support topk > 1 with custom mask
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(

@@ -894,7 +892,7 @@ class FlashInferMLAMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

         if topk > 1:
             raise ValueError(

@@ -963,7 +961,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()

         for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]

@@ -983,8 +981,6 @@ class FlashInferMLAMultiStepDraftBackend:
         )

         def call_fn(i, forward_batch):
-            assert forward_batch.spec_info is not None
-            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
python/sglang/srt/layers/attention/flashmla_backend.py

@@ -19,7 +19,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput

 # FlashMLA only supports pagesize=64

@@ -187,7 +187,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)

@@ -257,7 +257,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
python/sglang/srt/layers/attention/hybrid_attn_backend.py

@@ -6,7 +6,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput


 class HybridAttnBackend(AttentionBackend):

@@ -71,7 +71,7 @@ class HybridAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         backend = self._select_backend(forward_mode)
         backend.init_forward_metadata_capture_cuda_graph(

@@ -92,7 +92,7 @@ class HybridAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         backend = self._select_backend(forward_mode)
python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py

@@ -21,8 +21,8 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.models.qwen3_next import fused_gdn_gating
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import is_cuda, is_npu

 if is_cuda():

@@ -134,7 +134,7 @@ class MambaAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             self.query_start_loc_list[bs - 1].copy_(

@@ -161,7 +161,7 @@ class MambaAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         num_padding = torch.count_nonzero(

@@ -451,7 +451,7 @@ class HybridLinearAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         for attn_backend in self.attn_backend_list:
             attn_backend.init_forward_metadata_capture_cuda_graph(

@@ -472,7 +472,7 @@ class HybridLinearAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         for attn_backend in self.attn_backend_list:
python/sglang/srt/layers/attention/tbo_backend.py

-from typing import TYPE_CHECKING, Callable, List, Optional, Union
+from typing import TYPE_CHECKING, Callable, List, Optional

 import torch

 from sglang.srt import two_batch_overlap
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput

 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

@@ -46,7 +46,7 @@ class TboAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: "ForwardMode",
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         self.primary.init_forward_metadata_capture_cuda_graph(
             bs=bs,

@@ -77,7 +77,7 @@ class TboAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: "ForwardMode",
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         self.primary.init_forward_metadata_replay_cuda_graph(

@@ -112,7 +112,7 @@ class TboAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: "ForwardMode",
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         # capture args
         capture_num_tokens: int = None,
         # replay args

@@ -196,7 +196,7 @@ def _init_forward_metadata_cuda_graph_split(
     seq_lens: torch.Tensor,
     encoder_lens: Optional[torch.Tensor],
     forward_mode: "ForwardMode",
-    spec_info: Optional[EagleVerifyInput],
+    spec_info: Optional[SpecInput],
     # capture args
     capture_num_tokens: int = None,
     # replay args
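`TboAttnBackend` never inspects `spec_info`; it forwards it to the primary backend (and, via the split helper, to the two child micro-batches), so only its annotations change. A trimmed sketch of that delegation; forwarding of the keywords beyond `bs=bs` is assumed from the signature rather than shown in the hunk:

    def init_forward_metadata_capture_cuda_graph(
        self, bs, num_tokens, req_pool_indices, seq_lens,
        encoder_lens, forward_mode, spec_info,  # spec_info: Optional[SpecInput]
    ):
        # Pure pass-through: narrowing the annotation to Optional[SpecInput]
        # documents that any speculative algorithm may flow through here.
        self.primary.init_forward_metadata_capture_cuda_graph(
            bs=bs,
            num_tokens=num_tokens,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
        )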
python/sglang/srt/layers/attention/triton_backend.py

@@ -22,7 +22,7 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInput


 def logit_capping_mod(logit_capping_method, logit_cap):

@@ -482,7 +482,7 @@ class TritonAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         assert encoder_lens is None, "Not supported"

         window_kv_indptr = self.window_kv_indptr

@@ -638,7 +638,7 @@ class TritonAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # NOTE: encoder_lens expected to be zeros or None

@@ -883,7 +883,7 @@ class TritonMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
python/sglang/srt/layers/attention/trtllm_mha_backend.py

@@ -20,12 +20,10 @@ from sglang.srt.utils import is_flashinfer_available
 if is_flashinfer_available():
     import flashinfer

-from sglang.srt.speculative.eagle_utils import EagleDraftInput
-
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput

 # Constants
 DEFAULT_WORKSPACE_SIZE_MB = (

@@ -201,7 +199,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         """Initialize metadata for CUDA graph capture."""
         metadata = TRTLLMMHAMetadata()

@@ -314,7 +312,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""

@@ -661,7 +659,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
         forward_batch: ForwardBatch,
     ):
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()

         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata_capture_cuda_graph(

@@ -678,7 +676,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
         self, forward_batch: ForwardBatch, bs: int
     ):
         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()

         for i in range(self.speculative_num_steps - 1):
python/sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -30,7 +30,7 @@ if is_flashinfer_available():
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
+    from sglang.srt.speculative.spec_info import SpecInput

 _is_cuda = is_cuda()

@@ -214,7 +214,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
     ):
         """Initialize metadata for CUDA graph capture."""

@@ -270,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
python/sglang/srt/layers/attention/wave_backend.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional

 import torch
 import triton

@@ -17,7 +17,7 @@ from sglang.srt.utils import get_bool_env_var, get_device_core_count
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInput

 logger = logging.getLogger(__name__)

@@ -393,7 +393,7 @@ class WaveAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         assert encoder_lens is None, "Not supported"

@@ -477,7 +477,7 @@ class WaveAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # NOTE: encoder_lens expected to be zeros or None
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -11,12 +11,8 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
-    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.distributed.device_communicators.pynccl_allocator import (
-    use_symmetric_memory,
-)
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe import (
     MoeRunnerConfig,

@@ -24,7 +20,6 @@ from sglang.srt.layers.moe import (
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.token_dispatcher.standard import (
-    CombineInput,
     StandardDispatcher,
     StandardDispatchOutput,
 )
python/sglang/srt/managers/schedule_batch.py

@@ -73,9 +73,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-    from sglang.srt.speculative.ngram_utils import NgramVerifyInput
-    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+    from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -957,9 +955,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
-        None
-    )
+    spec_info: Optional[SpecInput] = None

     # Whether to return hidden states
     return_hidden_states: bool = False

@@ -1995,9 +1992,9 @@ class ModelWorkerBatch:
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
-        None
-    )
+    spec_info: Optional[SpecInput] = None

     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1
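On the batch side the change is purely declarative: one field type instead of a three-way union, with the imports reduced to a single `TYPE_CHECKING` line. A minimal runnable sketch of the simplified field (the real dataclasses carry many more fields):

    from __future__ import annotations

    from dataclasses import dataclass
    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm


    @dataclass
    class ModelWorkerBatch:
        # Speculative decoding: a single field regardless of whether the
        # payload is an Eagle draft, Eagle verify, or n-gram verify input.
        spec_algorithm: SpeculativeAlgorithm = None
        spec_info: Optional[SpecInput] = None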
python/sglang/srt/model_executor/cpu_graph_runner.py

@@ -607,7 +607,7 @@ class CPUGraphRunner:
     def get_spec_info(self, num_tokens: int):
         spec_info = None
         if self.model_runner.spec_algorithm.is_eagle():
-            from sglang.srt.speculative.eagle_utils import EagleVerifyInput
+            from sglang.srt.speculative.eagle_info import EagleVerifyInput

             if self.model_runner.is_draft_worker:
                 raise RuntimeError("This should not happen.")