Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
73d4a5f8
Unverified
Commit
73d4a5f8
authored
Oct 01, 2025
by
Liangsheng Yin
Committed by
GitHub
Oct 01, 2025
Browse files
Organize spec-related data structures (#10735)
parent
7fb551a7
Changes
32
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
95 additions
and
117 deletions
+95
-117
python/sglang/srt/constrained/outlines_jump_forward.py
python/sglang/srt/constrained/outlines_jump_forward.py
+1
-1
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
.../sglang/srt/disaggregation/decode_schedule_batch_mixin.py
+1
-1
python/sglang/srt/layers/attention/aiter_backend.py
python/sglang/srt/layers/attention/aiter_backend.py
+9
-14
python/sglang/srt/layers/attention/ascend_backend.py
python/sglang/srt/layers/attention/ascend_backend.py
+4
-4
python/sglang/srt/layers/attention/base_attn_backend.py
python/sglang/srt/layers/attention/base_attn_backend.py
+3
-3
python/sglang/srt/layers/attention/cutlass_mla_backend.py
python/sglang/srt/layers/attention/cutlass_mla_backend.py
+3
-3
python/sglang/srt/layers/attention/flashattention_backend.py
python/sglang/srt/layers/attention/flashattention_backend.py
+5
-6
python/sglang/srt/layers/attention/flashinfer_backend.py
python/sglang/srt/layers/attention/flashinfer_backend.py
+17
-19
python/sglang/srt/layers/attention/flashinfer_mla_backend.py
python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+11
-15
python/sglang/srt/layers/attention/flashmla_backend.py
python/sglang/srt/layers/attention/flashmla_backend.py
+3
-3
python/sglang/srt/layers/attention/hybrid_attn_backend.py
python/sglang/srt/layers/attention/hybrid_attn_backend.py
+3
-3
python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
...sglang/srt/layers/attention/hybrid_linear_attn_backend.py
+6
-6
python/sglang/srt/layers/attention/tbo_backend.py
python/sglang/srt/layers/attention/tbo_backend.py
+6
-6
python/sglang/srt/layers/attention/triton_backend.py
python/sglang/srt/layers/attention/triton_backend.py
+4
-4
python/sglang/srt/layers/attention/trtllm_mha_backend.py
python/sglang/srt/layers/attention/trtllm_mha_backend.py
+5
-7
python/sglang/srt/layers/attention/trtllm_mla_backend.py
python/sglang/srt/layers/attention/trtllm_mla_backend.py
+3
-3
python/sglang/srt/layers/attention/wave_backend.py
python/sglang/srt/layers/attention/wave_backend.py
+4
-4
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+0
-5
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+6
-9
python/sglang/srt/model_executor/cpu_graph_runner.py
python/sglang/srt/model_executor/cpu_graph_runner.py
+1
-1
No files found.
python/sglang/srt/constrained/outlines_jump_forward.py
View file @
73d4a5f8
...
@@ -37,7 +37,7 @@ except ImportError:
...
@@ -37,7 +37,7 @@ except ImportError:
IP_REGEX
=
r
"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
IP_REGEX
=
r
"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
# Env var was set in sglang.srt.server_args.ServerArgs.__post_
_
init__
# Env var was set in sglang.srt.server_args.ServerArgs.__post_init__
DISABLE_DISK_CACHE
=
get_bool_env_var
(
"SGLANG_DISABLE_OUTLINES_DISK_CACHE"
,
"true"
)
DISABLE_DISK_CACHE
=
get_bool_env_var
(
"SGLANG_DISABLE_OUTLINES_DISK_CACHE"
,
"true"
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
View file @
73d4a5f8
...
@@ -157,7 +157,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
...
@@ -157,7 +157,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
hidden_states
=
torch
.
stack
(
hidden_states_list
,
dim
=
0
).
to
(
self
.
device
)
hidden_states
=
torch
.
stack
(
hidden_states_list
,
dim
=
0
).
to
(
self
.
device
)
# local import to avoid circular import
# local import to avoid circular import
from
sglang.srt.speculative.eagle_
utils
import
EagleDraftInput
from
sglang.srt.speculative.eagle_
info
import
EagleDraftInput
spec_info
=
EagleDraftInput
(
spec_info
=
EagleDraftInput
(
topk_p
=
topk_p
,
topk_p
=
topk_p
,
...
...
python/sglang/srt/layers/attention/aiter_backend.py
View file @
73d4a5f8
...
@@ -4,18 +4,13 @@ from __future__ import annotations
...
@@ -4,18 +4,13 @@ from __future__ import annotations
end to end attention solution with aiter kernels
end to end attention solution with aiter kernels
"""
"""
import
math
import
os
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
functools
import
partial
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Union
import
torch
import
torch
import
triton
import
triton
import
triton.language
as
tl
from
sglang.global_config
import
global_config
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.dp_attention
import
(
from
sglang.srt.layers.dp_attention
import
(
...
@@ -27,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
...
@@ -27,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.speculative.spec_info
import
SpecIn
fo
from
sglang.srt.speculative.spec_info
import
SpecIn
put
try
:
try
:
from
aiter
import
(
from
aiter
import
(
...
@@ -374,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
...
@@ -374,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
):
):
if
forward_mode
.
is_decode_or_idle
():
if
forward_mode
.
is_decode_or_idle
():
qo_indptr
=
None
qo_indptr
=
None
...
@@ -509,7 +504,7 @@ class AiterAttnBackend(AttentionBackend):
...
@@ -509,7 +504,7 @@ class AiterAttnBackend(AttentionBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
if
forward_mode
.
is_decode_or_idle
():
if
forward_mode
.
is_decode_or_idle
():
...
@@ -888,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
...
@@ -888,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
prefix_lens
:
torch
.
Tensor
,
prefix_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
):
):
# Keep the signature for type checking. It will be assigned during runtime.
# Keep the signature for type checking. It will be assigned during runtime.
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -900,7 +895,7 @@ class AiterIndicesUpdaterPrefill:
...
@@ -900,7 +895,7 @@ class AiterIndicesUpdaterPrefill:
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
prefix_lens
:
torch
.
Tensor
,
prefix_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
):
):
kv_start_idx
=
None
kv_start_idx
=
None
...
@@ -984,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
...
@@ -984,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
extend_lens
:
torch
.
Tensor
,
extend_lens
:
torch
.
Tensor
,
max_q_len
:
int
,
max_q_len
:
int
,
max_kv_len
:
int
,
max_kv_len
:
int
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
):
):
# Keep the signature for type checking. It will be assigned during runtime.
# Keep the signature for type checking. It will be assigned during runtime.
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -997,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill:
...
@@ -997,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill:
extend_lens
:
torch
.
Tensor
,
extend_lens
:
torch
.
Tensor
,
max_q_len
:
int
,
max_q_len
:
int
,
max_kv_len
:
int
,
max_kv_len
:
int
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
):
):
bs
=
len
(
req_pool_indices
)
bs
=
len
(
req_pool_indices
)
...
@@ -1054,7 +1049,7 @@ class AiterMultiStepDraftBackend:
...
@@ -1054,7 +1049,7 @@ class AiterMultiStepDraftBackend:
topk
:
int
,
topk
:
int
,
speculative_num_steps
:
int
,
speculative_num_steps
:
int
,
):
):
from
sglang.srt.speculative.
eagle
_utils
import
generate_draft_decode_kv_indices
from
sglang.srt.speculative.
spec
_utils
import
generate_draft_decode_kv_indices
self
.
topk
=
topk
self
.
topk
=
topk
self
.
speculative_num_steps
=
speculative_num_steps
self
.
speculative_num_steps
=
speculative_num_steps
...
...
python/sglang/srt/layers/attention/ascend_backend.py
View file @
73d4a5f8
...
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, List, Optional
...
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, List, Optional
import
torch
import
torch
import
torch_npu
import
torch_npu
from
torch.nn.functional
import
scaled_dot_product_attention
from
sglang.srt.configs.model_config
import
AttentionArch
from
sglang.srt.configs.model_config
import
AttentionArch
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
...
@@ -13,7 +12,8 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess
...
@@ -13,7 +12,8 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess
from
sglang.srt.layers.attention.torch_native_backend
import
TorchNativeAttnBackend
from
sglang.srt.layers.attention.torch_native_backend
import
TorchNativeAttnBackend
from
sglang.srt.layers.dp_attention
import
get_attention_tp_size
from
sglang.srt.layers.dp_attention
import
get_attention_tp_size
from
sglang.srt.layers.radix_attention
import
AttentionType
from
sglang.srt.layers.radix_attention
import
AttentionType
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.speculative.spec_info
import
SpecInput
from
sglang.srt.utils
import
get_bool_env_var
from
sglang.srt.utils
import
get_bool_env_var
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -127,7 +127,7 @@ class AscendAttnBackend(AttentionBackend):
...
@@ -127,7 +127,7 @@ class AscendAttnBackend(AttentionBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
):
metadata
=
ForwardMetadata
()
metadata
=
ForwardMetadata
()
...
@@ -147,7 +147,7 @@ class AscendAttnBackend(AttentionBackend):
...
@@ -147,7 +147,7 @@ class AscendAttnBackend(AttentionBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
metadata
=
self
.
graph_metadata
[
bs
]
metadata
=
self
.
graph_metadata
[
bs
]
...
...
python/sglang/srt/layers/attention/base_attn_backend.py
View file @
73d4a5f8
...
@@ -8,7 +8,7 @@ import torch
...
@@ -8,7 +8,7 @@ import torch
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.speculative.
eagle_utils
import
EagleDraftInput
,
EagleVerify
Input
from
sglang.srt.speculative.
spec_info
import
Spec
Input
class
AttentionBackend
(
ABC
):
class
AttentionBackend
(
ABC
):
...
@@ -31,7 +31,7 @@ class AttentionBackend(ABC):
...
@@ -31,7 +31,7 @@ class AttentionBackend(ABC):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
):
"""Init the metadata for a forward pass for capturing a cuda graph."""
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -44,7 +44,7 @@ class AttentionBackend(ABC):
...
@@ -44,7 +44,7 @@ class AttentionBackend(ABC):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
"""Init the metadata for a forward pass for replaying a cuda graph."""
"""Init the metadata for a forward pass for replaying a cuda graph."""
...
...
python/sglang/srt/layers/attention/cutlass_mla_backend.py
View file @
73d4a5f8
...
@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
...
@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.speculative.spec_info
import
SpecIn
fo
from
sglang.srt.speculative.spec_info
import
SpecIn
put
_is_cuda
=
is_cuda
()
_is_cuda
=
is_cuda
()
if
_is_cuda
:
if
_is_cuda
:
...
@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
...
@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
):
):
if
forward_mode
.
is_decode_or_idle
():
if
forward_mode
.
is_decode_or_idle
():
if
spec_info
is
None
:
if
spec_info
is
None
:
...
@@ -190,7 +190,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
...
@@ -190,7 +190,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
...
...
python/sglang/srt/layers/attention/flashattention_backend.py
View file @
73d4a5f8
...
@@ -11,9 +11,8 @@ import triton.language as tl
...
@@ -11,9 +11,8 @@ import triton.language as tl
from
sglang.srt.configs.model_config
import
AttentionArch
from
sglang.srt.configs.model_config
import
AttentionArch
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.mem_cache.memory_pool
import
SWAKVPool
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.speculative.
eagle_utils
import
EagleDraftInput
,
EagleVerify
Input
from
sglang.srt.speculative.
spec_info
import
Spec
Input
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
...
@@ -1487,7 +1486,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -1487,7 +1486,7 @@ class FlashAttentionBackend(AttentionBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
):
"""Initialize forward metadata for capturing CUDA graph."""
"""Initialize forward metadata for capturing CUDA graph."""
metadata
=
FlashAttentionMetadata
()
metadata
=
FlashAttentionMetadata
()
...
@@ -1722,7 +1721,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -1722,7 +1721,7 @@ class FlashAttentionBackend(AttentionBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
out_cache_loc
:
Optional
[
torch
.
Tensor
]
=
None
,
out_cache_loc
:
Optional
[
torch
.
Tensor
]
=
None
,
):
):
...
@@ -2340,7 +2339,7 @@ class FlashAttentionMultiStepBackend:
...
@@ -2340,7 +2339,7 @@ class FlashAttentionMultiStepBackend:
forward_batch
:
ForwardBatch
,
forward_batch
:
ForwardBatch
,
):
):
assert
forward_batch
.
spec_info
is
not
None
assert
forward_batch
.
spec_info
is
not
None
assert
isinstance
(
forward_batch
.
spec_info
,
EagleD
raft
I
nput
)
assert
forward_batch
.
spec_info
.
is_d
raft
_i
nput
(
)
for
i
in
range
(
self
.
speculative_num_steps
-
1
):
for
i
in
range
(
self
.
speculative_num_steps
-
1
):
self
.
attn_backends
[
i
].
init_forward_metadata_capture_cuda_graph
(
self
.
attn_backends
[
i
].
init_forward_metadata_capture_cuda_graph
(
...
@@ -2357,7 +2356,7 @@ class FlashAttentionMultiStepBackend:
...
@@ -2357,7 +2356,7 @@ class FlashAttentionMultiStepBackend:
self
,
forward_batch
:
ForwardBatch
,
bs
:
int
self
,
forward_batch
:
ForwardBatch
,
bs
:
int
):
):
assert
forward_batch
.
spec_info
is
not
None
assert
forward_batch
.
spec_info
is
not
None
assert
isinstance
(
forward_batch
.
spec_info
,
EagleD
raft
I
nput
)
assert
forward_batch
.
spec_info
.
is_d
raft
_i
nput
(
)
for
i
in
range
(
self
.
speculative_num_steps
-
1
):
for
i
in
range
(
self
.
speculative_num_steps
-
1
):
# TODO: incrementally update the metadata for the later steps,
# TODO: incrementally update the metadata for the later steps,
...
...
python/sglang/srt/layers/attention/flashinfer_backend.py
View file @
73d4a5f8
...
@@ -28,8 +28,8 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
...
@@ -28,8 +28,8 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
from
sglang.srt.layers.radix_attention
import
AttentionType
from
sglang.srt.layers.radix_attention
import
AttentionType
from
sglang.srt.mem_cache.allocator
import
SWATokenToKVPoolAllocator
from
sglang.srt.mem_cache.allocator
import
SWATokenToKVPoolAllocator
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.speculative.eagle_
utils
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.speculative.eagle_
info
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.speculative.
ngram_utils
import
NgramVerify
Input
from
sglang.srt.speculative.
spec_info
import
Spec
Input
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
get_int_env_var
,
get_int_env_var
,
is_flashinfer_available
,
is_flashinfer_available
,
...
@@ -344,7 +344,7 @@ class FlashInferAttnBackend(AttentionBackend):
...
@@ -344,7 +344,7 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
,
NgramVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
):
if
forward_mode
.
is_decode_or_idle
():
if
forward_mode
.
is_decode_or_idle
():
decode_wrappers
=
[]
decode_wrappers
=
[]
...
@@ -451,7 +451,7 @@ class FlashInferAttnBackend(AttentionBackend):
...
@@ -451,7 +451,7 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
,
NgramVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
if
forward_mode
.
is_decode_or_idle
():
if
forward_mode
.
is_decode_or_idle
():
...
@@ -669,7 +669,7 @@ class FlashInferIndicesUpdaterDecode:
...
@@ -669,7 +669,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
decode_wrappers
:
List
[
BatchDecodeWithPagedKVCacheWrapper
],
decode_wrappers
:
List
[
BatchDecodeWithPagedKVCacheWrapper
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
,
NgramVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
fixed_split_size
:
Optional
[
int
]
=
None
,
fixed_split_size
:
Optional
[
int
]
=
None
,
disable_split_kv
:
Optional
[
bool
]
=
None
,
disable_split_kv
:
Optional
[
bool
]
=
None
,
):
):
...
@@ -684,7 +684,7 @@ class FlashInferIndicesUpdaterDecode:
...
@@ -684,7 +684,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
decode_wrappers
:
List
[
BatchDecodeWithPagedKVCacheWrapper
],
decode_wrappers
:
List
[
BatchDecodeWithPagedKVCacheWrapper
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
,
NgramVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
fixed_split_size
:
Optional
[
int
]
=
None
,
fixed_split_size
:
Optional
[
int
]
=
None
,
disable_split_kv
:
Optional
[
bool
]
=
None
,
disable_split_kv
:
Optional
[
bool
]
=
None
,
):
):
...
@@ -710,7 +710,7 @@ class FlashInferIndicesUpdaterDecode:
...
@@ -710,7 +710,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
decode_wrappers
:
List
[
BatchDecodeWithPagedKVCacheWrapper
],
decode_wrappers
:
List
[
BatchDecodeWithPagedKVCacheWrapper
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
,
NgramVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
fixed_split_size
:
Optional
[
int
]
=
None
,
fixed_split_size
:
Optional
[
int
]
=
None
,
disable_split_kv
:
Optional
[
bool
]
=
None
,
disable_split_kv
:
Optional
[
bool
]
=
None
,
):
):
...
@@ -760,7 +760,7 @@ class FlashInferIndicesUpdaterDecode:
...
@@ -760,7 +760,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
decode_wrappers
:
List
[
BatchDecodeWithPagedKVCacheWrapper
],
decode_wrappers
:
List
[
BatchDecodeWithPagedKVCacheWrapper
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
,
NgramVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
fixed_split_size
:
Optional
[
int
]
=
None
,
fixed_split_size
:
Optional
[
int
]
=
None
,
disable_split_kv
:
Optional
[
bool
]
=
None
,
disable_split_kv
:
Optional
[
bool
]
=
None
,
):
):
...
@@ -794,7 +794,7 @@ class FlashInferIndicesUpdaterDecode:
...
@@ -794,7 +794,7 @@ class FlashInferIndicesUpdaterDecode:
paged_kernel_lens_sum
:
int
,
paged_kernel_lens_sum
:
int
,
kv_indptr
:
torch
.
Tensor
,
kv_indptr
:
torch
.
Tensor
,
kv_start_idx
:
torch
.
Tensor
,
kv_start_idx
:
torch
.
Tensor
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
,
NgramVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
use_sliding_window_kv_pool
:
bool
=
False
,
use_sliding_window_kv_pool
:
bool
=
False
,
fixed_split_size
:
Optional
[
int
]
=
None
,
fixed_split_size
:
Optional
[
int
]
=
None
,
...
@@ -905,7 +905,7 @@ class FlashInferIndicesUpdaterPrefill:
...
@@ -905,7 +905,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers
:
List
[
BatchPrefillWithPagedKVCacheWrapper
],
prefill_wrappers
:
List
[
BatchPrefillWithPagedKVCacheWrapper
],
use_ragged
:
bool
,
use_ragged
:
bool
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
,
NgramVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
fixed_split_size
:
Optional
[
int
]
=
None
,
fixed_split_size
:
Optional
[
int
]
=
None
,
):
):
# Keep the signature for type checking. It will be assigned during runtime.
# Keep the signature for type checking. It will be assigned during runtime.
...
@@ -921,7 +921,7 @@ class FlashInferIndicesUpdaterPrefill:
...
@@ -921,7 +921,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers
:
List
[
BatchPrefillWithPagedKVCacheWrapper
],
prefill_wrappers
:
List
[
BatchPrefillWithPagedKVCacheWrapper
],
use_ragged
:
bool
,
use_ragged
:
bool
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
,
NgramVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
fixed_split_size
:
Optional
[
int
]
=
None
,
fixed_split_size
:
Optional
[
int
]
=
None
,
):
):
if
use_ragged
:
if
use_ragged
:
...
@@ -959,7 +959,7 @@ class FlashInferIndicesUpdaterPrefill:
...
@@ -959,7 +959,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers
:
List
[
BatchPrefillWithPagedKVCacheWrapper
],
prefill_wrappers
:
List
[
BatchPrefillWithPagedKVCacheWrapper
],
use_ragged
:
bool
,
use_ragged
:
bool
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
,
NgramVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
fixed_split_size
:
Optional
[
int
]
=
None
,
fixed_split_size
:
Optional
[
int
]
=
None
,
):
):
for
wrapper_id
in
range
(
2
):
for
wrapper_id
in
range
(
2
):
...
@@ -1006,7 +1006,7 @@ class FlashInferIndicesUpdaterPrefill:
...
@@ -1006,7 +1006,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers
:
List
[
BatchPrefillWithPagedKVCacheWrapper
],
prefill_wrappers
:
List
[
BatchPrefillWithPagedKVCacheWrapper
],
use_ragged
:
bool
,
use_ragged
:
bool
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
,
NgramVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
fixed_split_size
:
Optional
[
int
]
=
None
,
fixed_split_size
:
Optional
[
int
]
=
None
,
):
):
for
wrapper_id
in
range
(
2
):
for
wrapper_id
in
range
(
2
):
...
@@ -1049,7 +1049,7 @@ class FlashInferIndicesUpdaterPrefill:
...
@@ -1049,7 +1049,7 @@ class FlashInferIndicesUpdaterPrefill:
kv_indptr
:
torch
.
Tensor
,
kv_indptr
:
torch
.
Tensor
,
qo_indptr
:
torch
.
Tensor
,
qo_indptr
:
torch
.
Tensor
,
use_ragged
:
bool
,
use_ragged
:
bool
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
,
NgramVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
use_sliding_window_kv_pool
:
bool
=
False
,
use_sliding_window_kv_pool
:
bool
=
False
,
fixed_split_size
:
Optional
[
int
]
=
None
,
fixed_split_size
:
Optional
[
int
]
=
None
,
):
):
...
@@ -1077,9 +1077,7 @@ class FlashInferIndicesUpdaterPrefill:
...
@@ -1077,9 +1077,7 @@ class FlashInferIndicesUpdaterPrefill:
qo_indptr
=
qo_indptr
[:
bs
+
1
]
qo_indptr
=
qo_indptr
[:
bs
+
1
]
custom_mask
=
None
custom_mask
=
None
else
:
else
:
assert
isinstance
(
assert
isinstance
(
spec_info
,
SpecInput
)
spec_info
,
(
EagleDraftInput
,
EagleVerifyInput
,
NgramVerifyInput
)
)
kv_indices
,
kv_indptr
,
qo_indptr
,
custom_mask
=
(
kv_indices
,
kv_indptr
,
qo_indptr
,
custom_mask
=
(
spec_info
.
generate_attn_arg_prefill
(
spec_info
.
generate_attn_arg_prefill
(
req_pool_indices
,
req_pool_indices
,
...
@@ -1138,7 +1136,7 @@ class FlashInferMultiStepDraftBackend:
...
@@ -1138,7 +1136,7 @@ class FlashInferMultiStepDraftBackend:
topk
:
int
,
topk
:
int
,
speculative_num_steps
:
int
,
speculative_num_steps
:
int
,
):
):
from
sglang.srt.speculative.
eagle
_utils
import
generate_draft_decode_kv_indices
from
sglang.srt.speculative.
spec
_utils
import
generate_draft_decode_kv_indices
self
.
topk
=
topk
self
.
topk
=
topk
self
.
speculative_num_steps
=
speculative_num_steps
self
.
speculative_num_steps
=
speculative_num_steps
...
@@ -1202,7 +1200,7 @@ class FlashInferMultiStepDraftBackend:
...
@@ -1202,7 +1200,7 @@ class FlashInferMultiStepDraftBackend:
)
)
assert
forward_batch
.
spec_info
is
not
None
assert
forward_batch
.
spec_info
is
not
None
assert
isinstance
(
forward_batch
.
spec_info
,
EagleD
raft
I
nput
)
assert
forward_batch
.
spec_info
.
is_d
raft
_i
nput
(
)
# Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
# Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
indptr_cpu_whole
=
self
.
kv_indptr
[:,
:
bs
+
1
].
cpu
()
indptr_cpu_whole
=
self
.
kv_indptr
[:,
:
bs
+
1
].
cpu
()
...
...
python/sglang/srt/layers/attention/flashinfer_mla_backend.py
View file @
73d4a5f8
...
@@ -30,7 +30,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
...
@@ -30,7 +30,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
from
sglang.srt.layers.dp_attention
import
get_attention_tp_size
from
sglang.srt.layers.dp_attention
import
get_attention_tp_size
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.speculative.
eagle_utils
import
EagleDraftInput
,
EagleVerify
Input
from
sglang.srt.speculative.
spec_info
import
Spec
Input
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
is_flashinfer_available
,
is_flashinfer_available
,
is_sm100_supported
,
is_sm100_supported
,
...
@@ -40,7 +40,7 @@ from sglang.srt.utils import (
...
@@ -40,7 +40,7 @@ from sglang.srt.utils import (
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.speculative.spec_info
import
SpecIn
fo
from
sglang.srt.speculative.spec_info
import
SpecIn
put
if
is_flashinfer_available
():
if
is_flashinfer_available
():
from
flashinfer
import
(
from
flashinfer
import
(
...
@@ -361,7 +361,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
...
@@ -361,7 +361,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
):
):
if
forward_mode
.
is_decode_or_idle
():
if
forward_mode
.
is_decode_or_idle
():
decode_wrapper
=
BatchMLAPagedAttentionWrapper
(
decode_wrapper
=
BatchMLAPagedAttentionWrapper
(
...
@@ -441,7 +441,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
...
@@ -441,7 +441,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
if
forward_mode
.
is_decode_or_idle
():
if
forward_mode
.
is_decode_or_idle
():
...
@@ -663,7 +663,7 @@ class FlashInferMLAIndicesUpdaterDecode:
...
@@ -663,7 +663,7 @@ class FlashInferMLAIndicesUpdaterDecode:
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
decode_wrapper
:
BatchMLAPagedAttentionWrapper
,
decode_wrapper
:
BatchMLAPagedAttentionWrapper
,
init_metadata_replay
:
bool
=
False
,
init_metadata_replay
:
bool
=
False
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
=
None
,
spec_info
:
Optional
[
Spec
Input
]
=
None
,
**
fast_decode_kwargs
,
**
fast_decode_kwargs
,
):
):
decode_wrapper
=
decode_wrapper
or
self
.
decode_wrapper
decode_wrapper
=
decode_wrapper
or
self
.
decode_wrapper
...
@@ -688,7 +688,7 @@ class FlashInferMLAIndicesUpdaterDecode:
...
@@ -688,7 +688,7 @@ class FlashInferMLAIndicesUpdaterDecode:
q_indptr
:
torch
.
Tensor
,
q_indptr
:
torch
.
Tensor
,
kv_indptr
:
torch
.
Tensor
,
kv_indptr
:
torch
.
Tensor
,
init_metadata_replay
:
bool
=
False
,
init_metadata_replay
:
bool
=
False
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
=
None
,
spec_info
:
Optional
[
Spec
Input
]
=
None
,
**
fast_decode_kwargs
,
**
fast_decode_kwargs
,
):
):
bs
=
len
(
req_pool_indices
)
bs
=
len
(
req_pool_indices
)
...
@@ -776,7 +776,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
...
@@ -776,7 +776,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
prefix_lens
:
torch
.
Tensor
,
prefix_lens
:
torch
.
Tensor
,
prefill_wrapper_paged
:
BatchMLAPagedAttentionWrapper
,
prefill_wrapper_paged
:
BatchMLAPagedAttentionWrapper
,
use_ragged
:
bool
,
use_ragged
:
bool
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
=
None
,
spec_info
:
Optional
[
Spec
Input
]
=
None
,
):
):
if
use_ragged
:
if
use_ragged
:
paged_kernel_lens
=
prefix_lens
paged_kernel_lens
=
prefix_lens
...
@@ -811,7 +811,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
...
@@ -811,7 +811,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
kv_indptr
:
torch
.
Tensor
,
kv_indptr
:
torch
.
Tensor
,
qo_indptr
:
torch
.
Tensor
,
qo_indptr
:
torch
.
Tensor
,
use_ragged
:
bool
,
use_ragged
:
bool
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
=
None
,
spec_info
:
Optional
[
Spec
Input
]
=
None
,
):
):
bs
=
len
(
seq_lens
)
bs
=
len
(
seq_lens
)
sm_scale
=
self
.
scaling
sm_scale
=
self
.
scaling
...
@@ -838,9 +838,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
...
@@ -838,9 +838,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
qo_indptr
=
qo_indptr
[:
bs
+
1
]
qo_indptr
=
qo_indptr
[:
bs
+
1
]
custom_mask
=
None
custom_mask
=
None
else
:
else
:
assert
isinstance
(
spec_info
,
EagleDraftInput
)
or
isinstance
(
assert
isinstance
(
spec_info
,
SpecInput
)
spec_info
,
EagleVerifyInput
)
# TODO: Support topk > 1 with custom mask
# TODO: Support topk > 1 with custom mask
kv_indices
,
kv_indptr
,
qo_indptr
,
custom_mask
=
(
kv_indices
,
kv_indptr
,
qo_indptr
,
custom_mask
=
(
spec_info
.
generate_attn_arg_prefill
(
spec_info
.
generate_attn_arg_prefill
(
...
@@ -894,7 +892,7 @@ class FlashInferMLAMultiStepDraftBackend:
...
@@ -894,7 +892,7 @@ class FlashInferMLAMultiStepDraftBackend:
topk
:
int
,
topk
:
int
,
speculative_num_steps
:
int
,
speculative_num_steps
:
int
,
):
):
from
sglang.srt.speculative.
eagle
_utils
import
generate_draft_decode_kv_indices
from
sglang.srt.speculative.
spec
_utils
import
generate_draft_decode_kv_indices
if
topk
>
1
:
if
topk
>
1
:
raise
ValueError
(
raise
ValueError
(
...
@@ -963,7 +961,7 @@ class FlashInferMLAMultiStepDraftBackend:
...
@@ -963,7 +961,7 @@ class FlashInferMLAMultiStepDraftBackend:
)
)
assert
forward_batch
.
spec_info
is
not
None
assert
forward_batch
.
spec_info
is
not
None
assert
isinstance
(
forward_batch
.
spec_info
,
EagleD
raft
I
nput
)
assert
forward_batch
.
spec_info
.
is_d
raft
_i
nput
(
)
for
i
in
range
(
self
.
speculative_num_steps
-
1
):
for
i
in
range
(
self
.
speculative_num_steps
-
1
):
forward_batch
.
spec_info
.
kv_indptr
=
self
.
kv_indptr
[
i
,
:
bs
+
1
]
forward_batch
.
spec_info
.
kv_indptr
=
self
.
kv_indptr
[
i
,
:
bs
+
1
]
...
@@ -983,8 +981,6 @@ class FlashInferMLAMultiStepDraftBackend:
...
@@ -983,8 +981,6 @@ class FlashInferMLAMultiStepDraftBackend:
)
)
def
call_fn
(
i
,
forward_batch
):
def
call_fn
(
i
,
forward_batch
):
assert
forward_batch
.
spec_info
is
not
None
assert
isinstance
(
forward_batch
.
spec_info
,
EagleDraftInput
)
forward_batch
.
spec_info
.
kv_indptr
=
(
forward_batch
.
spec_info
.
kv_indptr
=
(
forward_batch
.
spec_info
.
kv_indptr
.
clone
()
forward_batch
.
spec_info
.
kv_indptr
.
clone
()
)
)
...
...
python/sglang/srt/layers/attention/flashmla_backend.py
View file @
73d4a5f8
...
@@ -19,7 +19,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
...
@@ -19,7 +19,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.speculative.spec_info
import
SpecIn
fo
from
sglang.srt.speculative.spec_info
import
SpecIn
put
# FlashMLA only supports pagesize=64
# FlashMLA only supports pagesize=64
...
@@ -187,7 +187,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
...
@@ -187,7 +187,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
):
):
if
forward_mode
.
is_decode_or_idle
():
if
forward_mode
.
is_decode_or_idle
():
max_seqlen_pad
=
triton
.
cdiv
(
seq_lens
.
max
().
item
(),
PAGE_SIZE
)
max_seqlen_pad
=
triton
.
cdiv
(
seq_lens
.
max
().
item
(),
PAGE_SIZE
)
...
@@ -257,7 +257,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
...
@@ -257,7 +257,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
...
...
python/sglang/srt/layers/attention/hybrid_attn_backend.py
View file @
73d4a5f8
...
@@ -6,7 +6,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
...
@@ -6,7 +6,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.speculative.
eagle_utils
import
EagleDraftInput
,
EagleVerify
Input
from
sglang.srt.speculative.
spec_info
import
Spec
Input
class
HybridAttnBackend
(
AttentionBackend
):
class
HybridAttnBackend
(
AttentionBackend
):
...
@@ -71,7 +71,7 @@ class HybridAttnBackend(AttentionBackend):
...
@@ -71,7 +71,7 @@ class HybridAttnBackend(AttentionBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
):
backend
=
self
.
_select_backend
(
forward_mode
)
backend
=
self
.
_select_backend
(
forward_mode
)
backend
.
init_forward_metadata_capture_cuda_graph
(
backend
.
init_forward_metadata_capture_cuda_graph
(
...
@@ -92,7 +92,7 @@ class HybridAttnBackend(AttentionBackend):
...
@@ -92,7 +92,7 @@ class HybridAttnBackend(AttentionBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
backend
=
self
.
_select_backend
(
forward_mode
)
backend
=
self
.
_select_backend
(
forward_mode
)
...
...
python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
View file @
73d4a5f8
...
@@ -21,8 +21,8 @@ from sglang.srt.layers.radix_attention import RadixAttention
...
@@ -21,8 +21,8 @@ from sglang.srt.layers.radix_attention import RadixAttention
from
sglang.srt.mem_cache.memory_pool
import
HybridReqToTokenPool
from
sglang.srt.mem_cache.memory_pool
import
HybridReqToTokenPool
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.models.qwen3_next
import
Qwen3HybridLinearDecoderLayer
,
fused_gdn_gating
from
sglang.srt.models.qwen3_next
import
fused_gdn_gating
from
sglang.srt.speculative.
eagle_utils
import
EagleDraftInput
,
EagleVerify
Input
from
sglang.srt.speculative.
spec_info
import
Spec
Input
from
sglang.srt.utils
import
is_cuda
,
is_npu
from
sglang.srt.utils
import
is_cuda
,
is_npu
if
is_cuda
():
if
is_cuda
():
...
@@ -134,7 +134,7 @@ class MambaAttnBackend(AttentionBackend):
...
@@ -134,7 +134,7 @@ class MambaAttnBackend(AttentionBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
):
if
forward_mode
.
is_decode_or_idle
():
if
forward_mode
.
is_decode_or_idle
():
self
.
query_start_loc_list
[
bs
-
1
].
copy_
(
self
.
query_start_loc_list
[
bs
-
1
].
copy_
(
...
@@ -161,7 +161,7 @@ class MambaAttnBackend(AttentionBackend):
...
@@ -161,7 +161,7 @@ class MambaAttnBackend(AttentionBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
num_padding
=
torch
.
count_nonzero
(
num_padding
=
torch
.
count_nonzero
(
...
@@ -451,7 +451,7 @@ class HybridLinearAttnBackend(AttentionBackend):
...
@@ -451,7 +451,7 @@ class HybridLinearAttnBackend(AttentionBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
):
for
attn_backend
in
self
.
attn_backend_list
:
for
attn_backend
in
self
.
attn_backend_list
:
attn_backend
.
init_forward_metadata_capture_cuda_graph
(
attn_backend
.
init_forward_metadata_capture_cuda_graph
(
...
@@ -472,7 +472,7 @@ class HybridLinearAttnBackend(AttentionBackend):
...
@@ -472,7 +472,7 @@ class HybridLinearAttnBackend(AttentionBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
for
attn_backend
in
self
.
attn_backend_list
:
for
attn_backend
in
self
.
attn_backend_list
:
...
...
python/sglang/srt/layers/attention/tbo_backend.py
View file @
73d4a5f8
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Callable
,
List
,
Optional
import
torch
import
torch
from
sglang.srt
import
two_batch_overlap
from
sglang.srt
import
two_batch_overlap
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.layers.attention.base_attn_backend
import
AttentionBackend
from
sglang.srt.speculative.
eagle_utils
import
EagleDraftInput
,
EagleVerify
Input
from
sglang.srt.speculative.
spec_info
import
Spec
Input
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
,
ForwardMode
...
@@ -46,7 +46,7 @@ class TboAttnBackend(AttentionBackend):
...
@@ -46,7 +46,7 @@ class TboAttnBackend(AttentionBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
"ForwardMode"
,
forward_mode
:
"ForwardMode"
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
):
self
.
primary
.
init_forward_metadata_capture_cuda_graph
(
self
.
primary
.
init_forward_metadata_capture_cuda_graph
(
bs
=
bs
,
bs
=
bs
,
...
@@ -77,7 +77,7 @@ class TboAttnBackend(AttentionBackend):
...
@@ -77,7 +77,7 @@ class TboAttnBackend(AttentionBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
"ForwardMode"
,
forward_mode
:
"ForwardMode"
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
self
.
primary
.
init_forward_metadata_replay_cuda_graph
(
self
.
primary
.
init_forward_metadata_replay_cuda_graph
(
...
@@ -112,7 +112,7 @@ class TboAttnBackend(AttentionBackend):
...
@@ -112,7 +112,7 @@ class TboAttnBackend(AttentionBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
"ForwardMode"
,
forward_mode
:
"ForwardMode"
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
# capture args
# capture args
capture_num_tokens
:
int
=
None
,
capture_num_tokens
:
int
=
None
,
# replay args
# replay args
...
@@ -196,7 +196,7 @@ def _init_forward_metadata_cuda_graph_split(
...
@@ -196,7 +196,7 @@ def _init_forward_metadata_cuda_graph_split(
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
"ForwardMode"
,
forward_mode
:
"ForwardMode"
,
spec_info
:
Optional
[
EagleVerify
Input
],
spec_info
:
Optional
[
Spec
Input
],
# capture args
# capture args
capture_num_tokens
:
int
=
None
,
capture_num_tokens
:
int
=
None
,
# replay args
# replay args
...
...
python/sglang/srt/layers/attention/triton_backend.py
View file @
73d4a5f8
...
@@ -22,7 +22,7 @@ from sglang.srt.utils import (
...
@@ -22,7 +22,7 @@ from sglang.srt.utils import (
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.speculative.
eagle_utils
import
EagleDraftInput
,
EagleVerify
Input
from
sglang.srt.speculative.
spec_info
import
Spec
Input
def
logit_capping_mod
(
logit_capping_method
,
logit_cap
):
def
logit_capping_mod
(
logit_capping_method
,
logit_cap
):
...
@@ -482,7 +482,7 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -482,7 +482,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
):
assert
encoder_lens
is
None
,
"Not supported"
assert
encoder_lens
is
None
,
"Not supported"
window_kv_indptr
=
self
.
window_kv_indptr
window_kv_indptr
=
self
.
window_kv_indptr
...
@@ -638,7 +638,7 @@ class TritonAttnBackend(AttentionBackend):
...
@@ -638,7 +638,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
# NOTE: encoder_lens expected to be zeros or None
# NOTE: encoder_lens expected to be zeros or None
...
@@ -883,7 +883,7 @@ class TritonMultiStepDraftBackend:
...
@@ -883,7 +883,7 @@ class TritonMultiStepDraftBackend:
topk
:
int
,
topk
:
int
,
speculative_num_steps
:
int
,
speculative_num_steps
:
int
,
):
):
from
sglang.srt.speculative.
eagle
_utils
import
generate_draft_decode_kv_indices
from
sglang.srt.speculative.
spec
_utils
import
generate_draft_decode_kv_indices
self
.
topk
=
topk
self
.
topk
=
topk
self
.
speculative_num_steps
=
speculative_num_steps
self
.
speculative_num_steps
=
speculative_num_steps
...
...
python/sglang/srt/layers/attention/trtllm_mha_backend.py
View file @
73d4a5f8
...
@@ -20,12 +20,10 @@ from sglang.srt.utils import is_flashinfer_available
...
@@ -20,12 +20,10 @@ from sglang.srt.utils import is_flashinfer_available
if
is_flashinfer_available
():
if
is_flashinfer_available
():
import
flashinfer
import
flashinfer
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.speculative.spec_info
import
SpecIn
fo
from
sglang.srt.speculative.spec_info
import
SpecIn
put
# Constants
# Constants
DEFAULT_WORKSPACE_SIZE_MB
=
(
DEFAULT_WORKSPACE_SIZE_MB
=
(
...
@@ -201,7 +199,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
...
@@ -201,7 +199,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
):
):
"""Initialize metadata for CUDA graph capture."""
"""Initialize metadata for CUDA graph capture."""
metadata
=
TRTLLMMHAMetadata
()
metadata
=
TRTLLMMHAMetadata
()
...
@@ -314,7 +312,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
...
@@ -314,7 +312,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
"""Replay CUDA graph with new inputs."""
"""Replay CUDA graph with new inputs."""
...
@@ -661,7 +659,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
...
@@ -661,7 +659,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
forward_batch
:
ForwardBatch
,
forward_batch
:
ForwardBatch
,
):
):
assert
forward_batch
.
spec_info
is
not
None
assert
forward_batch
.
spec_info
is
not
None
assert
isinstance
(
forward_batch
.
spec_info
,
EagleD
raft
I
nput
)
assert
forward_batch
.
spec_info
.
is_d
raft
_i
nput
(
)
for
i
in
range
(
self
.
speculative_num_steps
-
1
):
for
i
in
range
(
self
.
speculative_num_steps
-
1
):
self
.
attn_backends
[
i
].
init_forward_metadata_capture_cuda_graph
(
self
.
attn_backends
[
i
].
init_forward_metadata_capture_cuda_graph
(
...
@@ -678,7 +676,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
...
@@ -678,7 +676,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
self
,
forward_batch
:
ForwardBatch
,
bs
:
int
self
,
forward_batch
:
ForwardBatch
,
bs
:
int
):
):
assert
forward_batch
.
spec_info
is
not
None
assert
forward_batch
.
spec_info
is
not
None
assert
isinstance
(
forward_batch
.
spec_info
,
EagleD
raft
I
nput
)
assert
forward_batch
.
spec_info
.
is_d
raft
_i
nput
(
)
for
i
in
range
(
self
.
speculative_num_steps
-
1
):
for
i
in
range
(
self
.
speculative_num_steps
-
1
):
...
...
python/sglang/srt/layers/attention/trtllm_mla_backend.py
View file @
73d4a5f8
...
@@ -30,7 +30,7 @@ if is_flashinfer_available():
...
@@ -30,7 +30,7 @@ if is_flashinfer_available():
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.speculative.spec_info
import
SpecIn
fo
from
sglang.srt.speculative.spec_info
import
SpecIn
put
_is_cuda
=
is_cuda
()
_is_cuda
=
is_cuda
()
...
@@ -214,7 +214,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
...
@@ -214,7 +214,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
):
):
"""Initialize metadata for CUDA graph capture."""
"""Initialize metadata for CUDA graph capture."""
...
@@ -270,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
...
@@ -270,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
SpecIn
fo
],
spec_info
:
Optional
[
SpecIn
put
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
"""Replay CUDA graph with new inputs."""
"""Replay CUDA graph with new inputs."""
...
...
python/sglang/srt/layers/attention/wave_backend.py
View file @
73d4a5f8
...
@@ -2,7 +2,7 @@ from __future__ import annotations
...
@@ -2,7 +2,7 @@ from __future__ import annotations
import
logging
import
logging
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
import
torch
import
torch
import
triton
import
triton
...
@@ -17,7 +17,7 @@ from sglang.srt.utils import get_bool_env_var, get_device_core_count
...
@@ -17,7 +17,7 @@ from sglang.srt.utils import get_bool_env_var, get_device_core_count
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.model_executor.model_runner
import
ModelRunner
from
sglang.srt.speculative.
eagle_utils
import
EagleDraftInput
,
EagleVerify
Input
from
sglang.srt.speculative.
spec_info
import
Spec
Input
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -393,7 +393,7 @@ class WaveAttnBackend(AttentionBackend):
...
@@ -393,7 +393,7 @@ class WaveAttnBackend(AttentionBackend):
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
):
):
assert
encoder_lens
is
None
,
"Not supported"
assert
encoder_lens
is
None
,
"Not supported"
...
@@ -477,7 +477,7 @@ class WaveAttnBackend(AttentionBackend):
...
@@ -477,7 +477,7 @@ class WaveAttnBackend(AttentionBackend):
seq_lens_sum
:
int
,
seq_lens_sum
:
int
,
encoder_lens
:
Optional
[
torch
.
Tensor
],
encoder_lens
:
Optional
[
torch
.
Tensor
],
forward_mode
:
ForwardMode
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerify
Input
]
]
,
spec_info
:
Optional
[
Spec
Input
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
seq_lens_cpu
:
Optional
[
torch
.
Tensor
],
):
):
# NOTE: encoder_lens expected to be zeros or None
# NOTE: encoder_lens expected to be zeros or None
...
...
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
View file @
73d4a5f8
...
@@ -11,12 +11,8 @@ from sglang.srt.distributed import (
...
@@ -11,12 +11,8 @@ from sglang.srt.distributed import (
get_moe_expert_parallel_world_size
,
get_moe_expert_parallel_world_size
,
get_moe_tensor_parallel_rank
,
get_moe_tensor_parallel_rank
,
get_moe_tensor_parallel_world_size
,
get_moe_tensor_parallel_world_size
,
get_tp_group
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce
,
)
)
from
sglang.srt.distributed.device_communicators.pynccl_allocator
import
(
use_symmetric_memory
,
)
from
sglang.srt.eplb.expert_location
import
get_global_expert_location_metadata
from
sglang.srt.eplb.expert_location
import
get_global_expert_location_metadata
from
sglang.srt.layers.moe
import
(
from
sglang.srt.layers.moe
import
(
MoeRunnerConfig
,
MoeRunnerConfig
,
...
@@ -24,7 +20,6 @@ from sglang.srt.layers.moe import (
...
@@ -24,7 +20,6 @@ from sglang.srt.layers.moe import (
should_use_flashinfer_trtllm_moe
,
should_use_flashinfer_trtllm_moe
,
)
)
from
sglang.srt.layers.moe.token_dispatcher.standard
import
(
from
sglang.srt.layers.moe.token_dispatcher.standard
import
(
CombineInput
,
StandardDispatcher
,
StandardDispatcher
,
StandardDispatchOutput
,
StandardDispatchOutput
,
)
)
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
73d4a5f8
...
@@ -73,9 +73,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
...
@@ -73,9 +73,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.speculative.spec_info
import
SpecInput
,
SpeculativeAlgorithm
from
sglang.srt.speculative.ngram_utils
import
NgramVerifyInput
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
INIT_INCREMENTAL_DETOKENIZATION_OFFSET
=
5
...
@@ -957,9 +955,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -957,9 +955,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Speculative decoding
# Speculative decoding
spec_algorithm
:
SpeculativeAlgorithm
=
None
spec_algorithm
:
SpeculativeAlgorithm
=
None
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
,
NgramVerifyInput
]]
=
(
# spec_info: Optional[SpecInput] = None
None
spec_info
:
Optional
[
SpecInput
]
=
None
)
# Whether to return hidden states
# Whether to return hidden states
return_hidden_states
:
bool
=
False
return_hidden_states
:
bool
=
False
...
@@ -1995,9 +1992,9 @@ class ModelWorkerBatch:
...
@@ -1995,9 +1992,9 @@ class ModelWorkerBatch:
# Speculative decoding
# Speculative decoding
spec_algorithm
:
SpeculativeAlgorithm
=
None
spec_algorithm
:
SpeculativeAlgorithm
=
None
spec_info
:
Optional
[
Union
[
EagleVerifyInput
,
EagleDraftInput
,
NgramVerifyInput
]]
=
(
None
spec_info
:
Optional
[
SpecInput
]
=
None
)
# If set, the output of the batch contains the hidden states of the run.
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode
:
CaptureHiddenMode
=
None
capture_hidden_mode
:
CaptureHiddenMode
=
None
hicache_consumer_index
:
int
=
-
1
hicache_consumer_index
:
int
=
-
1
...
...
python/sglang/srt/model_executor/cpu_graph_runner.py
View file @
73d4a5f8
...
@@ -607,7 +607,7 @@ class CPUGraphRunner:
...
@@ -607,7 +607,7 @@ class CPUGraphRunner:
def
get_spec_info
(
self
,
num_tokens
:
int
):
def
get_spec_info
(
self
,
num_tokens
:
int
):
spec_info
=
None
spec_info
=
None
if
self
.
model_runner
.
spec_algorithm
.
is_eagle
():
if
self
.
model_runner
.
spec_algorithm
.
is_eagle
():
from
sglang.srt.speculative.eagle_
utils
import
EagleVerifyInput
from
sglang.srt.speculative.eagle_
info
import
EagleVerifyInput
if
self
.
model_runner
.
is_draft_worker
:
if
self
.
model_runner
.
is_draft_worker
:
raise
RuntimeError
(
"This should not happen."
)
raise
RuntimeError
(
"This should not happen."
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment