zhaoyu6 / sglang · Commit cde5a6e3 (Unverified)

Abstraction for spec worker and code cleanup (#11643)

Authored Oct 17, 2025 by Liangsheng Yin; committed by GitHub on Oct 17, 2025.
Parent: 3e4c7da2
Showing 14 changed files with 706 additions and 460 deletions (+706 / -460).
python/sglang/srt/managers/schedule_batch.py                              +5   -34
python/sglang/srt/managers/scheduler.py                                   +4   -6
python/sglang/srt/managers/scheduler_output_processor_mixin.py            +44  -67
python/sglang/srt/managers/tp_worker.py                                   +140 -124
python/sglang/srt/model_executor/cuda_graph_runner.py                     +0   -2
python/sglang/srt/speculative/base_spec_worker.py                         +29  -0
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py            +5   -1
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py     +7   -2
python/sglang/srt/speculative/eagle_info.py                               +3   -3
python/sglang/srt/speculative/eagle_info_v2.py                            +30  -1
python/sglang/srt/speculative/eagle_worker.py                             +13  -40
python/sglang/srt/speculative/eagle_worker_v2.py                          +384 -166
python/sglang/srt/speculative/spec_utils.py                               +40  -2
python/sglang/srt/speculative/standalone_worker.py                        +2   -12
python/sglang/srt/managers/schedule_batch.py
@@ -1061,38 +1061,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        )
        return req_pool_indices

-    def allocate_for_eagle_v2(self):
-        from sglang.srt.speculative.eagle_info import EagleDraftInput
-        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
-
-        bs = self.batch_size()
-        assert self.spec_info.is_draft_input()
-        draft_input: EagleDraftInput = self.spec_info
-
-        # FIXME(lsyin): now implementation does not enable over-allocation
-        # Now seq_lens and allocate_lens are correct
-        self.maybe_wait_verify_done()
-        new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE
-        num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item()
-        out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens)
-        assign_req_to_token_pool[(bs,)](
-            self.req_pool_indices,
-            self.req_to_token_pool.req_to_token,
-            draft_input.allocate_lens,
-            new_allocate_lens,
-            out_cache_loc,
-            self.req_to_token_pool.req_to_token.shape[1],
-            next_power_of_2(bs),
-        )
-        draft_input.allocate_lens = new_allocate_lens
-
-        # FIXME(lsyin): remove seq_lens_sum calculation
-        self.seq_lens_cpu = self.seq_lens.cpu()
-        self.seq_lens_sum = self.seq_lens_cpu.sum().item()
-
    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []
@@ -1522,8 +1490,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        bs = len(self.reqs)

        if self.is_v2_eagle:
-            # FIXME(lsyin): make this sync optional
-            self.allocate_for_eagle_v2()
+            # TODO(spec-v2): all v2 spec should go through this path
+            from sglang.srt.speculative.eagle_info import EagleDraftInput
+
+            draft_input: EagleDraftInput = self.spec_info
+            draft_input.prepare_for_decode(self)

        if not self.spec_algorithm.is_none():
            # if spec decoding is used, the decode batch is prepared inside
python/sglang/srt/managers/scheduler.py
@@ -215,10 +215,10 @@ class GenerationBatchResult:
    delay_sample_func: Optional[callable] = None
    future_indices: Optional[FutureIndices] = None

-    # FIXME(lsyin): maybe move to <Better Place>?
+    # FIXME(lsyin): maybe move to a better place?
    # sync path: forward stream -> output processor
    accept_lens: Optional[torch.Tensor] = None
-    last_batch_allocate_lens: Optional[torch.Tensor] = None
+    allocate_lens: Optional[torch.Tensor] = None

    # relay path: forward stream -> next step forward
    next_draft_input: Optional[EagleDraftInput] = None
@@ -246,10 +246,8 @@ class GenerationBatchResult:
        if self.accept_lens is not None:
            self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)

-        if self.last_batch_allocate_lens is not None:
-            self.last_batch_allocate_lens = self.last_batch_allocate_lens.to(
-                "cpu", non_blocking=True
-            )
+        if self.allocate_lens is not None:
+            self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)

        self.copy_done.record()
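The copy_done field above follows a common CUDA pattern: the forward stream queues non-blocking device-to-host copies, records an event, and the output processor synchronizes on that event before reading the CPU tensors. Below is a minimal, self-contained sketch of that pattern in plain PyTorch (not sglang code; field names are illustrative, and it needs a CUDA device):

import torch

def async_copy_to_cpu(*tensors):
    # Queue device->host copies and return a CUDA event that marks their completion.
    # Note: the copies are only truly asynchronous when the destination is pinned memory.
    copy_done = torch.cuda.Event()
    cpu_tensors = [t.to("cpu", non_blocking=True) for t in tensors]
    copy_done.record()  # recorded on the current stream, after the queued copies
    return cpu_tensors, copy_done

if torch.cuda.is_available():
    accept_lens = torch.tensor([2, 1, 3], device="cuda")
    allocate_lens = torch.tensor([20, 20, 20], device="cuda")
    (accept_cpu, alloc_cpu), copy_done = async_copy_to_cpu(accept_lens, allocate_lens)
    copy_done.synchronize()  # the consumer waits here before calling .tolist()
    print(accept_cpu.tolist(), alloc_cpu.tolist())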
python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -42,23 +42,21 @@ class SchedulerOutputProcessorMixin:
        skip_stream_req = None

        if self.is_generation:
+            if result.copy_done is not None:
+                result.copy_done.synchronize()
+
            (
                logits_output,
                next_token_ids,
                extend_input_len_per_req,
                extend_logprob_start_len_per_req,
-                copy_done,
            ) = (
                result.logits_output,
                result.next_token_ids,
                result.extend_input_len_per_req,
                result.extend_logprob_start_len_per_req,
-                result.copy_done,
            )

-            if copy_done is not None:
-                copy_done.synchronize()
-
            # Move next_token_ids and logprobs to cpu
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
@@ -199,57 +197,52 @@ class SchedulerOutputProcessorMixin:
        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

-    def hacky_process_eagle_overlap_result(
-        self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch
-    ):
-        # TODO(lsyin): try use a copy stream to share SMs with forward
-        # FIXME(lsyin): better organize this token free logic in eagle-overlap
-        last_batch_allocate_lens_cpu = result.last_batch_allocate_lens.tolist()
-        accept_lens_cpu = result.accept_lens.tolist()
+    def _resolve_spec_overlap_token_ids(
+        self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch
+    ) -> List[List[int]]:
+        """Resolve the padding next token ids for speculative decoding with overlap."""
+        assert result.next_token_ids.is_cpu
+        assert result.accept_lens.is_cpu
+        assert result.allocate_lens.is_cpu
+
+        next_token_ids = result.next_token_ids.tolist()
+        accept_lens = result.accept_lens.tolist()
+        result.num_accepted_tokens = sum(accept_lens)

        predict_tokens = []
-        num_draft_tokens = self.draft_worker.speculative_num_draft_tokens
+        stride = self.draft_worker.speculative_num_draft_tokens
        for i, req in enumerate(batch.reqs):
            predict_tokens.append(
-                next_token_ids[
-                    i * num_draft_tokens : i * num_draft_tokens + accept_lens_cpu[i]
-                ]
+                next_token_ids[i * stride : i * stride + accept_lens[i]]
            )
            # FIXME(lsyin): move this update elsewhere
            req.spec_verify_ct += 1
-        return last_batch_allocate_lens_cpu, accept_lens_cpu, predict_tokens
+        return predict_tokens

    def process_batch_result_decode(
        self: Scheduler,
        batch: ScheduleBatch,
        result: GenerationBatchResult,
    ):
+        if result.copy_done is not None:
+            result.copy_done.synchronize()
+
-        logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
+        logits_output, next_token_ids, can_run_cuda_graph = (
            result.logits_output,
            result.next_token_ids,
            result.can_run_cuda_graph,
-            result.copy_done,
        )
        self.num_generated_tokens += len(batch.reqs)

-        if copy_done is not None:
-            copy_done.synchronize()
-
        if batch.spec_algorithm.is_none():
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()
        elif batch.is_v2_eagle:
-            (
-                last_batch_allocate_lens_cpu,
-                accept_lens_cpu,
-                next_token_ids,
-            ) = self.hacky_process_eagle_overlap_result(result, batch)
-            result.num_accepted_tokens = sum(accept_lens_cpu)
+            next_token_ids = self._resolve_spec_overlap_token_ids(result, batch)
+            allocate_lens_list = result.allocate_lens.tolist()
+            accept_lens_list = result.accept_lens.tolist()

        # FIXME(lsyin): we suppose we have already got the num_accepted_tokens in result
        self.num_generated_tokens += len(batch.reqs)
        if not self.spec_algorithm.is_none():
            self.update_spec_metrics(batch.batch_size(), result.num_accepted_tokens)
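To make the slicing in _resolve_spec_overlap_token_ids concrete, here is a small self-contained sketch with illustrative values: the verifier emits a fixed-size block of stride = speculative_num_draft_tokens candidate ids per request, and only the first accept_lens[i] entries of each block are kept.

def resolve_padded_token_ids(next_token_ids, accept_lens, stride):
    # next_token_ids is a flat, padded list with `stride` slots per request
    return [
        next_token_ids[i * stride : i * stride + n]
        for i, n in enumerate(accept_lens)
    ]

# Two requests with stride = 4; unused slots are padding.
flat_ids = [11, 12, 13, 0, 21, 22, 0, 0]
print(resolve_padded_token_ids(flat_ids, accept_lens=[3, 2], stride=4))
# -> [[11, 12, 13], [21, 22]]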
@@ -264,43 +257,38 @@ class SchedulerOutputProcessorMixin:
                continue

            if self.enable_overlap and req.finished():
+                indices_to_free = None
                if self.page_size == 1:
                    if batch.spec_algorithm.is_eagle():
-                        from sglang.srt.speculative.eagle_worker_v2 import (
-                            free_spec_dec_tokens_page_size_1,
-                        )
-
-                        free_spec_dec_tokens_page_size_1(
-                            self.req_to_token_pool,
-                            self.token_to_kv_pool_allocator,
-                            req,
-                            last_batch_allocate_lens_cpu[i],
-                            None,
-                        )
+                        from sglang.srt.speculative.eagle_info import EagleDraftInput
+
+                        end_p = allocate_lens_list[i]
+                        start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE
+                        indices_to_free = self.req_to_token_pool.req_to_token[
+                            req.req_pool_idx
+                        ][start_p:end_p]
                    else:
                        # Free the one extra delayed token
-                        self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
+                        indices_to_free = batch.out_cache_loc[i : i + 1]
                else:
                    if batch.spec_algorithm.is_eagle():
-                        # TODO(lsyin): support eagle with page_size > 1
+                        # TODO(spec-v2): support eagle with page_size > 1
                        raise NotImplementedError()
                    else:
                        if (
                            len(req.origin_input_ids) + len(req.output_ids) - 1
                        ) % self.page_size == 0:
                            # Only free when the extra token is in a new page
-                            self.token_to_kv_pool_allocator.free(
-                                batch.out_cache_loc[i : i + 1]
-                            )
+                            indices_to_free = batch.out_cache_loc[i : i + 1]
+
+                if indices_to_free is not None:
+                    self.token_to_kv_pool_allocator.free(indices_to_free)
                continue

            if batch.spec_algorithm.is_none():
                req.output_ids.append(next_token_id)
            elif batch.is_v2_eagle:
-                # FIXME(lsyin): non-overlap spec worker will solve the output_ids in speculative decoding
-                # !!!unify the logic here!!!
+                # Only v2 eagle's output_ids are updated here.
                req.output_ids.extend(next_token_id)

            req.check_finished()
@@ -308,24 +296,13 @@ class SchedulerOutputProcessorMixin:
            if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
-                # FIXME(lsyin): fix the messy logic here
-                # 1) when not overlap (v2 impl), we free the extra tokens in the req
-                # 2) when overlap and current batch is extend, we free the extra tokens in the req of the previous batch
-                from sglang.srt.speculative.eagle_worker_v2 import (
-                    free_spec_dec_tokens_page_size_1,
-                )
-
-                new_seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1
-                # FIXME(lsyin): remove this assert
-                assert new_seq_len == int(
-                    batch.seq_lens_cpu[i] + accept_lens_cpu[i]
-                ), f"{new_seq_len=} vs {batch.seq_lens_cpu[i] + accept_lens_cpu[i]=}"
-                free_spec_dec_tokens_page_size_1(
-                    self.req_to_token_pool,
-                    self.token_to_kv_pool_allocator,
-                    req,
-                    last_batch_allocate_lens_cpu[i],
-                    new_seq_len,
-                )
+                # 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration.
+                start_p = batch.seq_lens_cpu[i] + accept_lens_list[i]
+                end_p = allocate_lens_list[i]
+                indices_to_free = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    start_p:end_p
+                ]
+                self.token_to_kv_pool_allocator.free(indices_to_free)

            if self.server_args.disaggregation_decode_enable_offload_kvcache:
                # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
python/sglang/srt/managers/tp_worker.py
@@ -15,6 +15,7 @@
from __future__ import annotations

import logging
+from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional

import torch
@@ -54,7 +55,140 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


-class TpModelWorker:
+class BaseTpWorker(ABC):
+    @abstractmethod
+    def forward_batch_generation(self, forward_batch: ForwardBatch):
+        pass
+
+    @property
+    @abstractmethod
+    def model_runner(self) -> ModelRunner:
+        pass
+
+    @property
+    def sliding_window_size(self) -> Optional[int]:
+        return self.model_runner.sliding_window_size
+
+    @property
+    def is_hybrid(self) -> bool:
+        return self.model_runner.is_hybrid is not None
+
+    def get_tokens_per_layer_info(self):
+        return (
+            self.model_runner.full_max_total_num_tokens,
+            self.model_runner.swa_max_total_num_tokens,
+        )
+
+    def get_pad_input_ids_func(self):
+        return getattr(self.model_runner.model, "pad_input_ids", None)
+
+    def get_tp_group(self):
+        return self.model_runner.tp_group
+
+    def get_attention_tp_group(self):
+        return self.model_runner.attention_tp_group
+
+    def get_attention_tp_cpu_group(self):
+        return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
+
+    def get_memory_pool(self):
+        return (
+            self.model_runner.req_to_token_pool,
+            self.model_runner.token_to_kv_pool_allocator,
+        )
+
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        success, message = self.model_runner.update_weights_from_disk(
+            recv_req.model_path, recv_req.load_format
+        )
+        return success, message
+
+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.init_weights_update_group(
+            recv_req.master_address,
+            recv_req.master_port,
+            recv_req.rank_offset,
+            recv_req.world_size,
+            recv_req.group_name,
+            recv_req.backend,
+        )
+        return success, message
+
+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.destroy_weights_update_group(
+            recv_req.group_name,
+        )
+        return success, message
+
+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        success, message = (
+            self.model_runner.init_weights_send_group_for_remote_instance(
+                recv_req.master_address,
+                recv_req.ports,
+                recv_req.group_rank,
+                recv_req.world_size,
+                recv_req.group_name,
+                recv_req.backend,
+            )
+        )
+        return success, message
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        success, message = self.model_runner.send_weights_to_remote_instance(
+            recv_req.master_address,
+            recv_req.ports,
+            recv_req.group_name,
+        )
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.model_runner.update_weights_from_distributed(
+            recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
+        )
+        return success, message
+
+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        monkey_patch_torch_reductions()
+        success, message = self.model_runner.update_weights_from_tensor(
+            named_tensors=MultiprocessingSerializer.deserialize(
+                recv_req.serialized_named_tensors[self.tp_rank]
+            ),
+            load_format=recv_req.load_format,
+        )
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.model_runner.get_weights_by_name(
+            recv_req.name, recv_req.truncate_size
+        )
+        return parameter
+
+    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
+        result = self.model_runner.load_lora_adapter(recv_req.to_ref())
+        return result
+
+    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
+        result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
+        return result
+
+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
+
+    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        logits_output, _ = self.model_runner.forward(forward_batch)
+        embeddings = logits_output.embeddings
+        return embeddings
+
+
+class TpModelWorker(BaseTpWorker):
    """A tensor parallel model worker."""

    def __init__(
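The intent of BaseTpWorker is that a concrete worker only has to provide a model_runner property (plus forward_batch_generation), while the weight-update, LoRA, memory-pool, and embedding helpers are inherited and all delegate to that property. A minimal sketch of the delegation pattern, with hypothetical names rather than sglang code:

from abc import ABC, abstractmethod

class BaseWorkerSketch(ABC):
    @property
    @abstractmethod
    def model_runner(self):
        ...

    # Shared helper: works for any subclass because it goes through the property.
    def get_tp_group(self):
        return self.model_runner.tp_group

class WorkerSketch(BaseWorkerSketch):
    def __init__(self, runner):
        # mirrors TpModelWorker, which now stores the runner as self._model_runner
        self._model_runner = runner

    @property
    def model_runner(self):
        return self._model_runner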
@@ -92,7 +226,7 @@ class TpModelWorker:
            is_draft_model=is_draft_worker,
        )

-        self.model_runner = ModelRunner(
+        self._model_runner = ModelRunner(
            model_config=self.model_config,
            mem_fraction_static=server_args.mem_fraction_static,
            gpu_id=gpu_id,
@@ -171,6 +305,10 @@ class TpModelWorker:
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.hicache_layer_transfer_counter = None

+    @property
+    def model_runner(self) -> ModelRunner:
+        return self._model_runner
+
    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
        self.hicache_layer_transfer_counter = counter
@@ -193,38 +331,6 @@ class TpModelWorker:
            self.model_runner.token_to_kv_pool.size,
        )

-    @property
-    def sliding_window_size(self) -> Optional[int]:
-        return self.model_runner.sliding_window_size
-
-    @property
-    def is_hybrid(self) -> bool:
-        return self.model_runner.is_hybrid is not None
-
-    def get_tokens_per_layer_info(self):
-        return (
-            self.model_runner.full_max_total_num_tokens,
-            self.model_runner.swa_max_total_num_tokens,
-        )
-
-    def get_pad_input_ids_func(self):
-        return getattr(self.model_runner.model, "pad_input_ids", None)
-
-    def get_tp_group(self):
-        return self.model_runner.tp_group
-
-    def get_attention_tp_group(self):
-        return self.model_runner.attention_tp_group
-
-    def get_attention_tp_cpu_group(self):
-        return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
-
-    def get_memory_pool(self):
-        return (
-            self.model_runner.req_to_token_pool,
-            self.model_runner.token_to_kv_pool_allocator,
-        )
-
    def forward_batch_generation(
        self,
        model_worker_batch: ModelWorkerBatch,
@@ -313,93 +419,3 @@ class TpModelWorker:
            pp_hidden_states_proxy_tensors=pp_proxy_tensors,
            can_run_cuda_graph=can_run_cuda_graph,
        )

-    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output, _ = self.model_runner.forward(forward_batch)
-        embeddings = logits_output.embeddings
-        return embeddings
-
-    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
-        success, message = self.model_runner.update_weights_from_disk(
-            recv_req.model_path, recv_req.load_format
-        )
-        return success, message
-
-    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
-        success, message = self.model_runner.init_weights_update_group(
-            recv_req.master_address,
-            recv_req.master_port,
-            recv_req.rank_offset,
-            recv_req.world_size,
-            recv_req.group_name,
-            recv_req.backend,
-        )
-        return success, message
-
-    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
-        success, message = self.model_runner.destroy_weights_update_group(
-            recv_req.group_name,
-        )
-        return success, message
-
-    def init_weights_send_group_for_remote_instance(
-        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
-    ):
-        success, message = (
-            self.model_runner.init_weights_send_group_for_remote_instance(
-                recv_req.master_address,
-                recv_req.ports,
-                recv_req.group_rank,
-                recv_req.world_size,
-                recv_req.group_name,
-                recv_req.backend,
-            )
-        )
-        return success, message
-
-    def send_weights_to_remote_instance(
-        self, recv_req: SendWeightsToRemoteInstanceReqInput
-    ):
-        success, message = self.model_runner.send_weights_to_remote_instance(
-            recv_req.master_address,
-            recv_req.ports,
-            recv_req.group_name,
-        )
-        return success, message
-
-    def update_weights_from_distributed(
-        self, recv_req: UpdateWeightsFromDistributedReqInput
-    ):
-        success, message = self.model_runner.update_weights_from_distributed(
-            recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
-        )
-        return success, message
-
-    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
-        monkey_patch_torch_reductions()
-        success, message = self.model_runner.update_weights_from_tensor(
-            named_tensors=MultiprocessingSerializer.deserialize(
-                recv_req.serialized_named_tensors[self.tp_rank]
-            ),
-            load_format=recv_req.load_format,
-        )
-        return success, message
-
-    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
-        parameter = self.model_runner.get_weights_by_name(
-            recv_req.name, recv_req.truncate_size
-        )
-        return parameter
-
-    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
-        result = self.model_runner.load_lora_adapter(recv_req.to_ref())
-        return result
-
-    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
-        result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
-        return result
-
-    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
-        return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -53,7 +53,6 @@ from sglang.srt.utils import (
    empty_context,
    get_available_gpu_memory,
    get_bool_env_var,
    get_device_memory_capacity,
    is_hip,
    log_info_on_rank0,
    require_attn_tp_gather,
@@ -274,7 +273,6 @@ class CudaGraphRunner:
            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
        )

-        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
        self.encoder_len_fill_value = 0
        self.seq_lens_cpu = torch.full(
            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
python/sglang/srt/speculative/base_spec_worker.py
new file (mode 100644)
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from sglang.srt.managers.tp_worker import TpModelWorker


class BaseDraftWorker(ABC):
    @abstractmethod
    def draft():
        pass

    @abstractmethod
    def draft_extend():
        pass


class BaseSpecWorker(ABC):
    @property
    @abstractmethod
    def target_worker(self) -> TpModelWorker:
        pass

    @property
    @abstractmethod
    def draft_worker(self) -> BaseDraftWorker:
        pass
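A hypothetical sketch of how this interface is meant to be used (the names below are illustrative and not part of this commit): a spec worker pairs a target TpModelWorker with a draft worker, and a decode step drafts candidates, verifies them with the target model, then extends the draft model over the accepted tokens.

class ToyDraftWorker:
    """Stands in for a BaseDraftWorker implementation."""

    def draft(self):
        return [101, 102]  # propose candidate tokens

    def draft_extend(self):
        pass  # run the draft model over the tokens accepted by verification


class ToySpecWorker:
    """Mirrors BaseSpecWorker: exposes the target worker and the draft worker."""

    def __init__(self, target_worker, draft_worker):
        self._target_worker = target_worker
        self._draft_worker = draft_worker

    @property
    def target_worker(self):
        return self._target_worker

    @property
    def draft_worker(self):
        return self._draft_worker


spec_worker = ToySpecWorker(target_worker=None, draft_worker=ToyDraftWorker())
print(spec_worker.draft_worker.draft())  # -> [101, 102]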
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -40,6 +40,10 @@ class EAGLEDraftCudaGraphRunner:
    def __init__(self, eagle_worker: EAGLEWorker):
        # Parse args
        self.eagle_worker = eagle_worker
+        if not hasattr(eagle_worker, "model_runner"):
+            # V2: EagleDraftWorker
+            self.model_runner = model_runner = eagle_worker.draft_runner
+        else:
+            self.model_runner = model_runner = eagle_worker.model_runner
        self.graphs = {}
        self.output_buffers = {}
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -38,7 +38,12 @@ class EAGLEDraftExtendCudaGraphRunner:
    def __init__(self, eagle_worker: EAGLEWorker):
        # Parse args
        self.eagle_worker = eagle_worker
+        if not hasattr(eagle_worker, "model_runner"):
+            # V2: EagleDraftWorker
+            self.model_runner = model_runner = eagle_worker.draft_runner
+        else:
+            self.model_runner = model_runner = eagle_worker.model_runner
        self.graphs = {}
        self.output_buffers = {}
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
@@ -285,7 +290,7 @@ class EAGLEDraftExtendCudaGraphRunner:
        output_cache_loc_backup = forward_batch.out_cache_loc
        hidden_states_backup = forward_batch.spec_info.hidden_states

-        ret = self.eagle_worker.draft_model_runner.model.forward(
+        ret = self.model_runner.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
python/sglang/srt/speculative/eagle_info.py
@@ -574,6 +574,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
@dataclass
class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
+    # Constant: alloc length per decode step
+    ALLOC_LEN_PER_DECODE: ClassVar[int] = None
+
    # The inputs for decode
    # shape: (b, topk)
    topk_p: torch.Tensor = None
@@ -609,9 +612,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
    new_seq_lens: Optional[torch.Tensor] = None
    verify_done: Optional[torch.cuda.Event] = None

-    # FIXME(lsyin): remove this hack
-    ALLOC_LEN_PER_DECODE: ClassVar[int] = None
-
    def __post_init__(self):
        super().__init__(SpecInputType.EAGLE_DRAFT)
python/sglang/srt/speculative/eagle_info_v2.py
@@ -9,7 +9,8 @@ import triton
import triton.language as tl

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
+from sglang.srt.mem_cache.common import alloc_token_slots
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
@@ -72,6 +73,34 @@ def assign_draft_cache_locs_page_size_1(
@dataclass
class EagleDraftInputV2Mixin:
+    def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
+        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
+
+        bs = batch.batch_size()
+
+        # TODO(lsyin): implement over-allocation
+        # Now seq_lens and allocate_lens are correct
+        batch.maybe_wait_verify_done()
+
+        new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
+        num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
+        out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
+        assign_req_to_token_pool[(bs,)](
+            batch.req_pool_indices,
+            batch.req_to_token_pool.req_to_token,
+            self.allocate_lens,
+            new_allocate_lens,
+            out_cache_loc,
+            batch.req_to_token_pool.req_to_token.shape[1],
+            next_power_of_2(bs),
+        )
+        self.allocate_lens = new_allocate_lens
+
+        # FIXME(lsyin): make this sync optional
+        batch.seq_lens_cpu = batch.seq_lens.cpu()
+        batch.seq_lens_sum = batch.seq_lens_cpu.sum().item()
+
    def prepare_for_v2_draft(
        self: EagleDraftInput,
        req_to_token_pool: ReqToTokenPool,
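The allocation arithmetic in prepare_for_decode can be checked with a small numeric sketch (illustrative values; ALLOC_LEN_PER_DECODE is configured elsewhere in sglang): every request is topped up to seq_len + ALLOC_LEN_PER_DECODE reserved KV slots, and only the delta against what is already reserved is requested from the allocator.

import torch

ALLOC_LEN_PER_DECODE = 8                      # illustrative value
seq_lens = torch.tensor([10, 25, 7])          # current sequence lengths
allocate_lens = torch.tensor([12, 30, 10])    # slots already reserved per request

new_allocate_lens = seq_lens + ALLOC_LEN_PER_DECODE
num_needed_tokens = (new_allocate_lens - allocate_lens).sum().item()
print(new_allocate_lens.tolist(), num_needed_tokens)  # [18, 33, 15] 14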
python/sglang/srt/speculative/eagle_worker.py
import logging
-import os
import time
-from contextlib import contextmanager
from typing import List, Optional, Tuple

import torch
-from huggingface_hub import snapshot_download

-from sglang.srt.distributed import (
-    GroupCoordinator,
-    get_tp_group,
-    patch_tensor_parallel_group,
-)
+from sglang.srt.distributed import get_tp_group
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -47,15 +40,17 @@ from sglang.srt.speculative.eagle_utils import (
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.speculative.spec_utils import (
    assign_draft_cache_locs,
+    detect_nan,
+    draft_tp_context,
    fast_topk,
    generate_token_bitmask,
+    load_token_map,
    select_top_k_tokens,
)
from sglang.srt.utils import (
    empty_context,
    get_available_gpu_memory,
    get_bool_env_var,
    is_blackwell,
    is_cuda,
    next_power_of_2,
)
@@ -67,14 +62,6 @@ logger = logging.getLogger(__name__)
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")


-@contextmanager
-def draft_tp_context(tp_group: GroupCoordinator):
-    # Draft model doesn't use dp and has its own tp group.
-    # We disable mscclpp now because it doesn't support 2 comm groups.
-    with patch_tensor_parallel_group(tp_group):
-        yield
-
-
class EAGLEWorker(TpModelWorker):
    def __init__(
@@ -100,7 +87,6 @@ class EAGLEWorker(TpModelWorker):
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
-        self.padded_static_len = -1

        # Override the context length of the draft model to be the same as the target model.
        server_args.context_length = target_worker.model_runner.model_config.context_len
@@ -612,7 +598,8 @@ class EAGLEWorker(TpModelWorker):
        logits_output, _ = self.draft_model_runner.forward(
            forward_batch, skip_attn_backend_init=True
        )
-        self._detect_nan_if_needed(logits_output)
+        if self.server_args.enable_nan_detection:
+            detect_nan(logits_output)
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
        if self.hot_token_id is not None:
@@ -680,7 +667,9 @@ class EAGLEWorker(TpModelWorker):
            # and will be applied to produce wrong results
            batch.sampling_info.vocab_mask = None

-        self._detect_nan_if_needed(logits_output)
+        if self.enable_nan_detection:
+            detect_nan(logits_output)
        spec_info.hidden_states = logits_output.hidden_states
        res: EagleVerifyOutput = spec_info.verify(
            batch,
@@ -833,7 +822,8 @@ class EAGLEWorker(TpModelWorker):
        )
        forward_batch.return_logprob = False
        logits_output, _ = self.draft_model_runner.forward(forward_batch)
-        self._detect_nan_if_needed(logits_output)
+        if self.enable_nan_detection:
+            detect_nan(logits_output)
        assert isinstance(forward_batch.spec_info, EagleDraftInput)
        assert forward_batch.spec_info is batch.spec_info
        self.capture_for_decode(logits_output, forward_batch.spec_info)
@@ -928,7 +918,8 @@ class EAGLEWorker(TpModelWorker):
        )
        self.capture_for_decode(logits_output, forward_batch.spec_info)
-        self._detect_nan_if_needed(logits_output)
+        if self.enable_nan_detection:
+            detect_nan(logits_output)

        # Restore backup.
        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
@@ -948,24 +939,6 @@ class EAGLEWorker(TpModelWorker):
        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
        draft_input.hidden_states = logits_output.hidden_states

-    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
-        if self.enable_nan_detection:
-            logits = logits_output.next_token_logits
-            if torch.any(torch.isnan(logits)):
-                logger.error("Detected errors during sampling! NaN in the logits.")
-                raise ValueError("Detected errors during sampling! NaN in the logits.")
-
-
-def load_token_map(token_map_path: str) -> List[int]:
-    if not os.path.exists(token_map_path):
-        cache_dir = snapshot_download(
-            os.path.dirname(token_map_path),
-            ignore_patterns=["*.bin", "*.safetensors"],
-        )
-        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
-    hot_token_id = torch.load(token_map_path, weights_only=True)
-    return torch.tensor(hot_token_id, dtype=torch.int64)
-
-
@torch.compile(dynamic=True)
def get_last_loc_large_page_size_top_k_1(
python/sglang/srt/speculative/eagle_worker_v2.py
This diff is collapsed and not shown here.
python/sglang/srt/speculative/spec_utils.py
from __future__ import annotations

import logging
import os
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, List

import torch
import triton
import triton.language as tl
from huggingface_hub import snapshot_download

from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.distributed.parallel_state import (
    GroupCoordinator,
    patch_tensor_parallel_group,
)
from sglang.srt.environ import envs
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.utils import is_cuda, is_hip

if TYPE_CHECKING:
    from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
    from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
    from sglang.srt.speculative.eagle_info import EagleVerifyInput

if is_cuda():
    from sgl_kernel import fast_topk
elif is_hip():
    from sgl_kernel import fast_topk

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_info import EagleVerifyInput

logger = logging.getLogger(__name__)
@@ -603,3 +615,29 @@ def generate_token_bitmask(
    verify_input.grammar = grammar
    return allocate_token_bitmask
+
+
+def load_token_map(token_map_path: str) -> List[int]:
+    if not os.path.exists(token_map_path):
+        cache_dir = snapshot_download(
+            os.path.dirname(token_map_path),
+            ignore_patterns=["*.bin", "*.safetensors"],
+        )
+        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
+    hot_token_id = torch.load(token_map_path, weights_only=True)
+    return torch.tensor(hot_token_id, dtype=torch.int64)
+
+
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with patch_tensor_parallel_group(tp_group):
+        yield
+
+
+def detect_nan(logits_output: LogitsProcessorOutput):
+    logits = logits_output.next_token_logits
+    if torch.any(torch.isnan(logits)):
+        logger.error("Detected errors during sampling! NaN in the logits.")
+        raise ValueError("Detected errors during sampling! NaN in the logits.")
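With the move into spec_utils, callers now guard the check themselves (see the eagle_worker.py changes above). A tiny self-contained sketch of the guarded NaN check, with enable_nan_detection reduced to a plain boolean for illustration:

import torch

def detect_nan_sketch(next_token_logits: torch.Tensor):
    if torch.any(torch.isnan(next_token_logits)):
        raise ValueError("Detected errors during sampling! NaN in the logits.")

enable_nan_detection = True
logits = torch.tensor([[0.1, 0.9], [float("nan"), 0.5]])
if enable_nan_detection:
    try:
        detect_nan_sketch(logits)
    except ValueError as err:
        print(err)  # -> Detected errors during sampling! NaN in the logits.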
python/sglang/srt/speculative/standalone_worker.py
import logging
-from contextlib import contextmanager
from typing import Optional

import torch

-from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
-from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
+from sglang.srt.speculative.eagle_worker import EAGLEWorker
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.speculative.spec_utils import draft_tp_context, load_token_map
from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda

if is_cuda():
@@ -18,14 +17,6 @@ logger = logging.getLogger(__name__)
SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")


-@contextmanager
-def draft_tp_context(tp_group: GroupCoordinator):
-    # Draft model doesn't use dp and has its own tp group.
-    # We disable mscclpp now because it doesn't support 2 comm groups.
-    with patch_tensor_parallel_group(tp_group):
-        yield
-
-
class StandaloneWorker(EAGLEWorker):
    def __init__(
@@ -51,7 +42,6 @@ class StandaloneWorker(EAGLEWorker):
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
-        self.padded_static_len = -1

        # Override the context length of the draft model to be the same as the target model.
        server_args.context_length = target_worker.model_runner.model_config.context_len