sglang commit d88ac9bc (unverified)

[overlap-spec] Make plan stream an option (#11724)

Authored Oct 17, 2025 by Liangsheng Yin; committed by GitHub on Oct 17, 2025.
Parent: ce11dd82

Showing 5 changed files with 23 additions and 23 deletions:
- python/sglang/srt/environ.py (+3, -0)
- python/sglang/srt/layers/attention/triton_backend.py (+1, -1)
- python/sglang/srt/model_executor/forward_batch_info.py (+2, -0)
- python/sglang/srt/speculative/eagle_info_v2.py (+2, -3)
- python/sglang/srt/speculative/eagle_worker_v2.py (+15, -19)
python/sglang/srt/environ.py

@@ -221,6 +221,9 @@ class Envs:
     SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
     SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)
 
+    # Overlap Spec V2
+    SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
+
     # VLM
     SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
     SGLANG_RESIZE_RESAMPLE = EnvStr("")
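The new flag follows the existing pattern in `Envs`: `EnvBool(False)` keeps the separate plan stream disabled unless the variable is set in the environment. A minimal, hedged sketch of how such a gate behaves, using plain `os.environ` (the accepted truthy spellings are an assumption, not sglang's exact `EnvBool` parsing):

    import os

    # Stand-in for EnvBool(False): unset or unrecognized values mean "off".
    def overlap_plan_stream_enabled() -> bool:
        raw = os.environ.get("SGLANG_ENABLE_OVERLAP_PLAN_STREAM", "false")
        return raw.strip().lower() in ("1", "true", "yes", "on")

    # e.g. running with SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1 turns the option on
    print(overlap_plan_stream_enabled())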
python/sglang/srt/layers/attention/triton_backend.py

@@ -365,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
             )
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                forward_batch.extend_prefix_lens.sum().item(),
+                sum(forward_batch.extend_prefix_lens_cpu),
                 dtype=torch.int64,
                 device=self.device,
             )
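This one-line change swaps a device-side reduction for host arithmetic: `extend_prefix_lens.sum().item()` launches a GPU kernel and then blocks on a device-to-host copy, while `sum(extend_prefix_lens_cpu)` works on the CPU-side list and never synchronizes. That matters when planning code may run concurrently with forward compute. A hedged sketch of the difference, with local stand-in values rather than sglang's objects:

    import torch

    extend_prefix_lens_cpu = [5, 9, 3]  # per-request prefix lengths kept on the host
    device = "cuda" if torch.cuda.is_available() else "cpu"
    extend_prefix_lens = torch.tensor(extend_prefix_lens_cpu, device=device)

    # Before: .item() forces the host to wait for the device reduction.
    total_blocking = extend_prefix_lens.sum().item()

    # After: pure host arithmetic, no device synchronization.
    total_host = sum(extend_prefix_lens_cpu)

    assert total_blocking == total_host == 17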
python/sglang/srt/model_executor/forward_batch_info.py

@@ -404,6 +404,8 @@ class ForwardBatch:
         if ret.positions is None:
             ret.positions = clamp_position(batch.seq_lens)
         else:
+            assert isinstance(batch.extend_seq_lens, list)
+            assert isinstance(batch.extend_prefix_lens, list)
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
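The two asserts pin down the invariant this path now relies on: by this point the extend lengths must be plain Python lists (as the eagle_info_v2.py change below assigns them), so `torch.tensor(...)` is the right constructor before the asynchronous host-to-device copy. A hedged sketch with stand-in values:

    import torch

    extend_seq_lens = [4, 4, 4]  # stand-in: one entry per request
    assert isinstance(extend_seq_lens, list)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    extend_seq_lens_dev = torch.tensor(extend_seq_lens, dtype=torch.int32).to(
        device, non_blocking=True  # copy can overlap compute when the source is pinned
    )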
python/sglang/srt/speculative/eagle_info_v2.py

@@ -114,7 +114,7 @@ class EagleDraftInputV2Mixin:
         num_draft_tokens: int,
         draft_model_runner: Any,
     ):
-        seq_lens_cpu_backup = batch.seq_lens_cpu
+        seq_lens_cpu_ = batch.seq_lens_cpu
         extend_num_tokens = len(batch.seq_lens) * num_draft_tokens
         batch.spec_info = self

@@ -123,8 +123,7 @@ class EagleDraftInputV2Mixin:
         batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
         batch.seq_lens_sum += extend_num_tokens
         batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))]
-        batch.extend_prefix_lens = seq_lens_cpu_backup.tolist()
-        batch.extend_prefix_lens_cpu = seq_lens_cpu_backup
+        batch.extend_prefix_lens = seq_lens_cpu_.tolist()
         batch.extend_num_tokens = extend_num_tokens
         batch.capture_hidden_mode = CaptureHiddenMode.FULL
         batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2
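This rewrite prepares a draft-extend batch: every request grows by `num_draft_tokens`, the pre-extension sequence lengths become the prefixes, and the separate `extend_prefix_lens_cpu` copy is no longer stored on the batch. A hedged sketch of the arithmetic, with stand-in values:

    import torch

    seq_lens_cpu_ = torch.tensor([7, 12])  # pre-extension lengths (host copy)
    num_draft_tokens = 4
    bs = 2

    extend_num_tokens = bs * num_draft_tokens     # 8 draft tokens in total
    extend_seq_lens = [num_draft_tokens] * bs     # each request extends equally
    extend_prefix_lens = seq_lens_cpu_.tolist()   # old lengths become prefixes
    new_seq_lens_cpu = seq_lens_cpu_ + num_draft_tokens

    assert extend_num_tokens == sum(extend_seq_lens)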
python/sglang/srt/speculative/eagle_worker_v2.py

+import contextlib
 import logging
 from typing import List, Optional
 
 import torch
 from torch.cuda import Stream as CudaStream
 
+from sglang.srt.environ import envs
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
 from sglang.srt.managers.scheduler import GenerationBatchResult

@@ -50,9 +52,13 @@ class EAGLEWorkerV2(EAGLEWorker):
             self.speculative_num_steps * self.topk,
             self.speculative_num_draft_tokens,
         )
         self.tree_mask_mode = TreeMaskMode.FULL_MASK
-        self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
-        self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
+        # TODO(lsyin): potential bugs with a separate plan stream
+        if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get():
+            self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
+            self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
+        else:
+            self.plan_stream = None
+            self.plan_stream_ctx = contextlib.nullcontext()
 
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         if model_worker_batch.forward_mode.is_decode():
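The constructor now builds either a real side stream plus a `torch.cuda.stream(...)` context, or `None` plus `contextlib.nullcontext()`. Callers can then enter `self.plan_stream_ctx` unconditionally: with the flag off, the planning work simply runs inline on the current stream. A hedged, self-contained sketch of the pattern (`SideStreamPlanner` is a stand-in class, not sglang's worker):

    import contextlib
    import torch

    class SideStreamPlanner:
        def __init__(self, enable_side_stream: bool, device: str = "cuda"):
            if enable_side_stream and torch.cuda.is_available():
                # Same calls as the diff: a dedicated stream plus a reusable
                # context manager that makes it current.
                self.plan_stream = torch.get_device_module(device).Stream()
                self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
            else:
                self.plan_stream = None
                self.plan_stream_ctx = contextlib.nullcontext()

        def plan(self):
            # One code path for both modes: a null context runs the work
            # inline; a stream context queues it on the side stream.
            with self.plan_stream_ctx:
                pass  # metadata planning / kernel launches would go here

    SideStreamPlanner(enable_side_stream=False).plan()  # works without a GPU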
@@ -232,9 +238,13 @@ class EAGLEWorkerV2(EAGLEWorker):
         batch: ModelWorkerBatch,
         pre_draft_allocate_lens: torch.Tensor,
     ):
+        # Since batch.seq_lens is allocated in another stream, we need
+        # record_stream() to prevent pytorch gc and reuse the gpu memory
+        # while forward_stream is still running.
+        batch.seq_lens.record_stream(torch.cuda.current_stream())
+
         # Parse args
         verify_input: EagleVerifyInput = batch.spec_info
-        seq_lens_backup = batch.seq_lens
         bs = len(batch.seq_lens)
 
         # Batch 1: Target verify
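`Tensor.record_stream()` tells PyTorch's caching allocator that a tensor allocated on one stream is also consumed on the current stream, so freeing it cannot recycle the memory while that stream still has pending reads. The hunk needs this because `batch.seq_lens` may now be produced on the optional plan stream. A hedged sketch of the idiom, with stand-in tensors (requires CUDA):

    import torch

    if torch.cuda.is_available():
        side_stream = torch.cuda.Stream()

        # Producer: allocate and fill the tensor on a side stream.
        with torch.cuda.stream(side_stream):
            seq_lens = torch.ones(8, dtype=torch.int32, device="cuda")

        # Consumer on the default stream: first order the two streams, then
        # mark seq_lens as in use here so its memory is not handed out early
        # if the producer side drops its reference.
        torch.cuda.current_stream().wait_stream(side_stream)
        seq_lens.record_stream(torch.cuda.current_stream())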
@@ -280,17 +290,8 @@ class EAGLEWorkerV2(EAGLEWorker):
             accept_length,
             accept_index,
         ) = verify_input.sample(batch, logits_output)
-        new_seq_lens = seq_lens_backup + accept_length
+        new_seq_lens = batch.seq_lens + accept_length
         verify_done = torch.cuda.Event()
-
-        # Move the accepted tokens to the target KV cache locations
-        batch.seq_lens = seq_lens_backup
-        self.move_accepted_tokens_to_target_kvcache(
-            batch,
-            accept_index,
-            accept_length,
-        )
         verify_done.record()
 
         all_verified_id = predict[accept_index]
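`verify_done` survives the cleanup: a CUDA event recorded after the accept-length computation lets other streams (or the host) wait on exactly this point instead of synchronizing the whole device. A hedged sketch with stand-in work:

    import torch

    if torch.cuda.is_available():
        verify_done = torch.cuda.Event()

        x = torch.randn(1024, device="cuda")
        y = x * 2                      # queued on the current stream
        verify_done.record()           # marks completion of the work above

        other = torch.cuda.Stream()
        other.wait_event(verify_done)  # 'other' starts only after y is ready
        with torch.cuda.stream(other):
            z = y + 1
        torch.cuda.synchronize()       # join everything before inspecting results
        print(z[:4])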
@@ -341,11 +342,6 @@ class EAGLEWorkerV2(EAGLEWorker):
         ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
         ret_hidden_states = draft_logits_output.hidden_states
 
-        # Since seq_lens_backup's tensor is allocated in another stream, we
-        # need record_stream() to prevent pytorch gc and reuse the gpu memory
-        # while forward_stream is still running.
-        seq_lens_backup.record_stream(torch.cuda.current_stream())
-
         # Construct the return values
         next_draft_input = EagleDraftInput(
             topk_p=ret_topk_p,