sglang · Commit cde5a6e3 (unverified)

Abstraction for spec worker and code cleanup (#11643)

Authored Oct 17, 2025 by Liangsheng Yin; committed by GitHub on Oct 17, 2025. Parent: 3e4c7da2.

Showing 14 changed files with 706 additions and 460 deletions (+706 / -460).
Changed files:

python/sglang/srt/managers/schedule_batch.py                             +5   -34
python/sglang/srt/managers/scheduler.py                                  +4    -6
python/sglang/srt/managers/scheduler_output_processor_mixin.py          +44   -67
python/sglang/srt/managers/tp_worker.py                                +140  -124
python/sglang/srt/model_executor/cuda_graph_runner.py                    +0    -2
python/sglang/srt/speculative/base_spec_worker.py                       +29    -0
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py           +5    -1
python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py    +7    -2
python/sglang/srt/speculative/eagle_info.py                              +3    -3
python/sglang/srt/speculative/eagle_info_v2.py                          +30    -1
python/sglang/srt/speculative/eagle_worker.py                           +13   -40
python/sglang/srt/speculative/eagle_worker_v2.py                       +384  -166
python/sglang/srt/speculative/spec_utils.py                             +40    -2
python/sglang/srt/speculative/standalone_worker.py                       +2   -12
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -1061,38 +1061,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
         return req_pool_indices
 
-    def allocate_for_eagle_v2(self):
-        from sglang.srt.speculative.eagle_info import EagleDraftInput
-        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
-
-        bs = self.batch_size()
-        assert self.spec_info.is_draft_input()
-        draft_input: EagleDraftInput = self.spec_info
-
-        # FIXME(lsyin): now implementation does not enable over-allocation
-        # Now seq_lens and allocate_lens are correct
-        self.maybe_wait_verify_done()
-        new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE
-        num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item()
-        out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens)
-        assign_req_to_token_pool[(bs,)](
-            self.req_pool_indices,
-            self.req_to_token_pool.req_to_token,
-            draft_input.allocate_lens,
-            new_allocate_lens,
-            out_cache_loc,
-            self.req_to_token_pool.req_to_token.shape[1],
-            next_power_of_2(bs),
-        )
-        draft_input.allocate_lens = new_allocate_lens
-
-        # FIXME(lsyin): remove seq_lens_sum calculation
-        self.seq_lens_cpu = self.seq_lens.cpu()
-        self.seq_lens_sum = self.seq_lens_cpu.sum().item()
-
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
         self.encoder_cached = []
```
```diff
@@ -1522,8 +1490,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         bs = len(self.reqs)
 
         if self.is_v2_eagle:
-            # FIXME(lsyin): make this sync optional
-            self.allocate_for_eagle_v2()
+            # TODO(spec-v2): all v2 spec should go through this path
+            from sglang.srt.speculative.eagle_info import EagleDraftInput
+
+            draft_input: EagleDraftInput = self.spec_info
+            draft_input.prepare_for_decode(self)
 
         if not self.spec_algorithm.is_none():
             # if spec decoding is used, the decode batch is prepared inside
```
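The batch-side helper is gone; decode-time allocation now lives on the spec input object itself, and the batch only delegates. A minimal sketch of that delegation pattern, with hypothetical `Batch`/`DraftInput` stand-ins rather than the real sglang classes:

```python
# Sketch of the refactor direction: logic that only touches speculative
# state moves from the batch onto the spec-input object that owns it.
# `Batch` and `DraftInput` are illustrative stand-ins, not sglang classes.
from dataclasses import dataclass, field
from typing import List


@dataclass
class DraftInput:
    allocate_lens: List[int] = field(default_factory=list)

    def prepare_for_decode(self, batch: "Batch") -> None:
        # The spec input updates its own bookkeeping, reading only the
        # shared state it genuinely needs from the batch.
        self.allocate_lens = [n + batch.alloc_per_decode for n in batch.seq_lens]


@dataclass
class Batch:
    seq_lens: List[int]
    alloc_per_decode: int
    spec_info: DraftInput = field(default_factory=DraftInput)

    def prepare_for_decode(self) -> None:
        # Before: the batch reached into spec_info and did the work itself.
        # After: it just delegates.
        self.spec_info.prepare_for_decode(self)


batch = Batch(seq_lens=[7, 12], alloc_per_decode=4)
batch.prepare_for_decode()
assert batch.spec_info.allocate_lens == [11, 16]
```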
python/sglang/srt/managers/scheduler.py

```diff
@@ -215,10 +215,10 @@ class GenerationBatchResult:
     delay_sample_func: Optional[callable] = None
     future_indices: Optional[FutureIndices] = None
 
-    # FIXME(lsyin): maybe move to <Better Place>?
+    # FIXME(lsyin): maybe move to a better place?
     # sync path: forward stream -> output processor
     accept_lens: Optional[torch.Tensor] = None
-    last_batch_allocate_lens: Optional[torch.Tensor] = None
+    allocate_lens: Optional[torch.Tensor] = None
 
     # relay path: forward stream -> next step forward
     next_draft_input: Optional[EagleDraftInput] = None
@@ -246,10 +246,8 @@ class GenerationBatchResult:
         if self.accept_lens is not None:
             self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
 
-        if self.last_batch_allocate_lens is not None:
-            self.last_batch_allocate_lens = self.last_batch_allocate_lens.to(
-                "cpu", non_blocking=True
-            )
+        if self.allocate_lens is not None:
+            self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)
 
         self.copy_done.record()
```
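The copy path here is the standard CUDA overlap trick: queue device-to-host copies with `non_blocking=True` on the forward stream, record an event (`copy_done`), and have the consumer synchronize on that event before touching the host tensors. A minimal, illustrative sketch of the pattern with a plain tensor rather than sglang's result object:

```python
import torch

# Illustrative sketch of the copy_done pattern used by GenerationBatchResult:
# enqueue an async D2H copy, record an event, and synchronize before reading.
if torch.cuda.is_available():
    accept_lens = torch.tensor([3, 1, 2], device="cuda")
    copy_done = torch.cuda.Event()

    # Producer side (forward stream): the copy is queued, not awaited.
    # (For a truly asynchronous copy the destination should be pinned memory.)
    accept_lens_cpu = accept_lens.to("cpu", non_blocking=True)
    copy_done.record()

    # Consumer side (output processor): wait until the copy has landed,
    # after which it is safe to call .tolist() on the host tensor.
    copy_done.synchronize()
    print(accept_lens_cpu.tolist())  # [3, 1, 2]
```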
python/sglang/srt/managers/scheduler_output_processor_mixin.py

```diff
@@ -42,23 +42,21 @@ class SchedulerOutputProcessorMixin:
         skip_stream_req = None
 
         if self.is_generation:
+            if result.copy_done is not None:
+                result.copy_done.synchronize()
             (
                 logits_output,
                 next_token_ids,
                 extend_input_len_per_req,
                 extend_logprob_start_len_per_req,
-                copy_done,
             ) = (
                 result.logits_output,
                 result.next_token_ids,
                 result.extend_input_len_per_req,
                 result.extend_logprob_start_len_per_req,
-                result.copy_done,
             )
 
-            if copy_done is not None:
-                copy_done.synchronize()
-
             # Move next_token_ids and logprobs to cpu
             next_token_ids = next_token_ids.tolist()
             if batch.return_logprob:
```
```diff
@@ -199,57 +197,52 @@ class SchedulerOutputProcessorMixin:
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
 
-    def hacky_process_eagle_overlap_result(
+    def _resolve_spec_overlap_token_ids(
         self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch
-    ):
-        # TODO(lsyin): try use a copy stream to share SMs with forward
-        # FIXME(lsyin): better organize this token free logic in eagle-overlap
-        last_batch_allocate_lens_cpu = result.last_batch_allocate_lens.tolist()
-        accept_lens_cpu = result.accept_lens.tolist()
+    ) -> List[List[int]]:
+        """Resolve the padding next token ids for speculative decoding with overlap."""
+        assert result.next_token_ids.is_cpu
+        assert result.accept_lens.is_cpu
+        assert result.allocate_lens.is_cpu
+
         next_token_ids = result.next_token_ids.tolist()
+        accept_lens = result.accept_lens.tolist()
+        result.num_accepted_tokens = sum(accept_lens)
         predict_tokens = []
-        num_draft_tokens = self.draft_worker.speculative_num_draft_tokens
+        stride = self.draft_worker.speculative_num_draft_tokens
         for i, req in enumerate(batch.reqs):
             predict_tokens.append(
-                next_token_ids[
-                    i * num_draft_tokens : i * num_draft_tokens + accept_lens_cpu[i]
-                ]
+                next_token_ids[i * stride : i * stride + accept_lens[i]]
             )
+            # FIXME(lsyin): move this update elsewhere
             req.spec_verify_ct += 1
-        return last_batch_allocate_lens_cpu, accept_lens_cpu, predict_tokens
+        return predict_tokens
 
     def process_batch_result_decode(
         self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
     ):
-        logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
+        if result.copy_done is not None:
+            result.copy_done.synchronize()
+
+        logits_output, next_token_ids, can_run_cuda_graph = (
             result.logits_output,
             result.next_token_ids,
             result.can_run_cuda_graph,
-            result.copy_done,
         )
-        self.num_generated_tokens += len(batch.reqs)
-
-        if copy_done is not None:
-            copy_done.synchronize()
 
         if batch.spec_algorithm.is_none():
             next_token_ids = next_token_ids.tolist()
             if batch.return_logprob:
                 next_token_logprobs = logits_output.next_token_logprobs.tolist()
         elif batch.is_v2_eagle:
-            (
-                last_batch_allocate_lens_cpu,
-                accept_lens_cpu,
-                next_token_ids,
-            ) = self.hacky_process_eagle_overlap_result(result, batch)
-            result.num_accepted_tokens = sum(accept_lens_cpu)
+            next_token_ids = self._resolve_spec_overlap_token_ids(result, batch)
+            allocate_lens_list = result.allocate_lens.tolist()
+            accept_lens_list = result.accept_lens.tolist()
+
+        # FIXME(lsyin): we suppose we have already got the num_accepted_tokens in result
+        self.num_generated_tokens += len(batch.reqs)
 
         if not self.spec_algorithm.is_none():
             self.update_spec_metrics(batch.batch_size(), result.num_accepted_tokens)
```
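The slicing convention matters here: the verifier returns a flat, padded list of `bs * num_draft_tokens` candidate ids, and only the first `accept_lens[i]` entries of each request's stride-sized window are real. A small self-contained sketch of that unpadding step (toy values; `-1` marks padding):

```python
# Unpadding verified speculative tokens: each request owns a fixed-size
# window of `stride` slots, of which only `accept_lens[i]` were accepted.
stride = 4  # speculative_num_draft_tokens in sglang terms
next_token_ids = [11, 12, 13, -1,   # req 0: 3 accepted, 1 pad
                  21, -1, -1, -1,   # req 1: 1 accepted, 3 pads
                  31, 32, -1, -1]   # req 2: 2 accepted, 2 pads
accept_lens = [3, 1, 2]

predict_tokens = [
    next_token_ids[i * stride : i * stride + accept_lens[i]]
    for i in range(len(accept_lens))
]
assert predict_tokens == [[11, 12, 13], [21], [31, 32]]
```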
```diff
@@ -264,43 +257,38 @@ class SchedulerOutputProcessorMixin:
                 continue
 
             if self.enable_overlap and req.finished():
+                indices_to_free = None
                 if self.page_size == 1:
                     if batch.spec_algorithm.is_eagle():
-                        from sglang.srt.speculative.eagle_worker_v2 import (
-                            free_spec_dec_tokens_page_size_1,
-                        )
-
-                        free_spec_dec_tokens_page_size_1(
-                            self.req_to_token_pool,
-                            self.token_to_kv_pool_allocator,
-                            req,
-                            last_batch_allocate_lens_cpu[i],
-                            None,
-                        )
+                        from sglang.srt.speculative.eagle_info import EagleDraftInput
+
+                        end_p = allocate_lens_list[i]
+                        start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE
+                        indices_to_free = self.req_to_token_pool.req_to_token[
+                            req.req_pool_idx
+                        ][start_p:end_p]
                     else:
                         # Free the one extra delayed token
-                        self.token_to_kv_pool_allocator.free(
-                            batch.out_cache_loc[i : i + 1]
-                        )
+                        indices_to_free = batch.out_cache_loc[i : i + 1]
                 else:
                     if batch.spec_algorithm.is_eagle():
-                        # TODO(lsyin): support eagle with page_size > 1
+                        # TODO(spec-v2): support eagle with page_size > 1
                         raise NotImplementedError()
                     else:
                         if (
                             len(req.origin_input_ids) + len(req.output_ids) - 1
                         ) % self.page_size == 0:
                             # Only free when the extra token is in a new page
-                            self.token_to_kv_pool_allocator.free(
-                                batch.out_cache_loc[i : i + 1]
-                            )
+                            indices_to_free = batch.out_cache_loc[i : i + 1]
+
+                if indices_to_free is not None:
+                    self.token_to_kv_pool_allocator.free(indices_to_free)
                 continue
 
             if batch.spec_algorithm.is_none():
                 req.output_ids.append(next_token_id)
             elif batch.is_v2_eagle:
-                # FIXME(lsyin): non-overlap spec worker will solve the output_ids in speculative decoding
-                # !!!unify the logic here!!!
+                # Only v2 eagle's output_ids are updated here.
                 req.output_ids.extend(next_token_id)
 
             req.check_finished()
```
```diff
@@ -308,24 +296,13 @@ class SchedulerOutputProcessorMixin:
             if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
                 # FIXME(lsyin): fix the messy logic here
                 # 1) when not overlap (v2 impl), we free the extra tokens in the req
-                # 2) when overlap and current batch is extend, we free the extra tokens in the req of the previous batch
-                from sglang.srt.speculative.eagle_worker_v2 import (
-                    free_spec_dec_tokens_page_size_1,
-                )
-
-                new_seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1
-                # FIXME(lsyin): remove this assert
-                assert new_seq_len == int(
-                    batch.seq_lens_cpu[i] + accept_lens_cpu[i]
-                ), f"{new_seq_len=} vs {batch.seq_lens_cpu[i] + accept_lens_cpu[i]=}"
-                free_spec_dec_tokens_page_size_1(
-                    self.req_to_token_pool,
-                    self.token_to_kv_pool_allocator,
-                    req,
-                    last_batch_allocate_lens_cpu[i],
-                    new_seq_len,
-                )
+                # 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration.
+                start_p = batch.seq_lens_cpu[i] + accept_lens_list[i]
+                end_p = allocate_lens_list[i]
+                indices_to_free = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                    start_p:end_p
+                ]
+                self.token_to_kv_pool_allocator.free(indices_to_free)
 
             if self.server_args.disaggregation_decode_enable_offload_kvcache:
                 # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
```
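Both freeing hunks converge on the same shape: compute a `[start_p:end_p)` slice of the request's row in the token table and hand it to the allocator in a single `free` call. A toy sketch of trimming an over-allocated tail (the list-based `free_list` is a stand-in for sglang's token_to_kv_pool_allocator):

```python
# Toy sketch of freeing the over-allocated tail of a request's KV slots.
req_to_token = [[100, 101, 102, 103, 104, 105, 106, 107]]  # one request's slot row
allocated_len = 8          # allocate_lens[i]: slots reserved for the step
used_len = 5               # seq_len + accepted tokens actually written

start_p, end_p = used_len, allocated_len
indices_to_free = req_to_token[0][start_p:end_p]

free_list = []                       # the allocator's pool of reusable slots
free_list.extend(indices_to_free)    # one free() call for the whole tail
assert free_list == [105, 106, 107]
```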
python/sglang/srt/managers/tp_worker.py

```diff
@@ -15,6 +15,7 @@
 from __future__ import annotations
 
 import logging
+from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Optional
 
 import torch
@@ -54,7 +55,140 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class TpModelWorker:
+class BaseTpWorker(ABC):
+    @abstractmethod
+    def forward_batch_generation(self, forward_batch: ForwardBatch):
+        pass
+
+    @property
+    @abstractmethod
+    def model_runner(self) -> ModelRunner:
+        pass
+
+    @property
+    def sliding_window_size(self) -> Optional[int]:
+        return self.model_runner.sliding_window_size
+
+    @property
+    def is_hybrid(self) -> bool:
+        return self.model_runner.is_hybrid is not None
+
+    def get_tokens_per_layer_info(self):
+        return (
+            self.model_runner.full_max_total_num_tokens,
+            self.model_runner.swa_max_total_num_tokens,
+        )
+
+    def get_pad_input_ids_func(self):
+        return getattr(self.model_runner.model, "pad_input_ids", None)
+
+    def get_tp_group(self):
+        return self.model_runner.tp_group
+
+    def get_attention_tp_group(self):
+        return self.model_runner.attention_tp_group
+
+    def get_attention_tp_cpu_group(self):
+        return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
+
+    def get_memory_pool(self):
+        return (
+            self.model_runner.req_to_token_pool,
+            self.model_runner.token_to_kv_pool_allocator,
+        )
+
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        success, message = self.model_runner.update_weights_from_disk(
+            recv_req.model_path, recv_req.load_format
+        )
+        return success, message
+
+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.init_weights_update_group(
+            recv_req.master_address,
+            recv_req.master_port,
+            recv_req.rank_offset,
+            recv_req.world_size,
+            recv_req.group_name,
+            recv_req.backend,
+        )
+        return success, message
+
+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.destroy_weights_update_group(
+            recv_req.group_name,
+        )
+        return success, message
+
+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        success, message = (
+            self.model_runner.init_weights_send_group_for_remote_instance(
+                recv_req.master_address,
+                recv_req.ports,
+                recv_req.group_rank,
+                recv_req.world_size,
+                recv_req.group_name,
+                recv_req.backend,
+            )
+        )
+        return success, message
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        success, message = self.model_runner.send_weights_to_remote_instance(
+            recv_req.master_address,
+            recv_req.ports,
+            recv_req.group_name,
+        )
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.model_runner.update_weights_from_distributed(
+            recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
+        )
+        return success, message
+
+    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+        monkey_patch_torch_reductions()
+        success, message = self.model_runner.update_weights_from_tensor(
+            named_tensors=MultiprocessingSerializer.deserialize(
+                recv_req.serialized_named_tensors[self.tp_rank]
+            ),
+            load_format=recv_req.load_format,
+        )
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.model_runner.get_weights_by_name(
+            recv_req.name, recv_req.truncate_size
+        )
+        return parameter
+
+    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
+        result = self.model_runner.load_lora_adapter(recv_req.to_ref())
+        return result
+
+    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
+        result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
+        return result
+
+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
+
+    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        logits_output, _ = self.model_runner.forward(forward_batch)
+        embeddings = logits_output.embeddings
+        return embeddings
+
+
+class TpModelWorker(BaseTpWorker):
     """A tensor parallel model worker."""
 
     def __init__(
```
```diff
@@ -92,7 +226,7 @@ class TpModelWorker:
             is_draft_model=is_draft_worker,
         )
 
-        self.model_runner = ModelRunner(
+        self._model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             gpu_id=gpu_id,
@@ -171,6 +305,10 @@ class TpModelWorker:
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.hicache_layer_transfer_counter = None
 
+    @property
+    def model_runner(self) -> ModelRunner:
+        return self._model_runner
+
     def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter
@@ -193,38 +331,6 @@ class TpModelWorker:
             self.model_runner.token_to_kv_pool.size,
         )
 
-    @property
-    def sliding_window_size(self) -> Optional[int]:
-        return self.model_runner.sliding_window_size
-
-    @property
-    def is_hybrid(self) -> bool:
-        return self.model_runner.is_hybrid is not None
-
-    def get_tokens_per_layer_info(self):
-        return (
-            self.model_runner.full_max_total_num_tokens,
-            self.model_runner.swa_max_total_num_tokens,
-        )
-
-    def get_pad_input_ids_func(self):
-        return getattr(self.model_runner.model, "pad_input_ids", None)
-
-    def get_tp_group(self):
-        return self.model_runner.tp_group
-
-    def get_attention_tp_group(self):
-        return self.model_runner.attention_tp_group
-
-    def get_attention_tp_cpu_group(self):
-        return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
-
-    def get_memory_pool(self):
-        return (
-            self.model_runner.req_to_token_pool,
-            self.model_runner.token_to_kv_pool_allocator,
-        )
-
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
```
```diff
@@ -313,93 +419,3 @@ class TpModelWorker:
             pp_hidden_states_proxy_tensors=pp_proxy_tensors,
             can_run_cuda_graph=can_run_cuda_graph,
         )
-
-    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output, _ = self.model_runner.forward(forward_batch)
-        embeddings = logits_output.embeddings
-        return embeddings
-
-    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
-        success, message = self.model_runner.update_weights_from_disk(
-            recv_req.model_path, recv_req.load_format
-        )
-        return success, message
-
-    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
-        success, message = self.model_runner.init_weights_update_group(
-            recv_req.master_address,
-            recv_req.master_port,
-            recv_req.rank_offset,
-            recv_req.world_size,
-            recv_req.group_name,
-            recv_req.backend,
-        )
-        return success, message
-
-    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
-        success, message = self.model_runner.destroy_weights_update_group(
-            recv_req.group_name,
-        )
-        return success, message
-
-    def init_weights_send_group_for_remote_instance(
-        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
-    ):
-        success, message = (
-            self.model_runner.init_weights_send_group_for_remote_instance(
-                recv_req.master_address,
-                recv_req.ports,
-                recv_req.group_rank,
-                recv_req.world_size,
-                recv_req.group_name,
-                recv_req.backend,
-            )
-        )
-        return success, message
-
-    def send_weights_to_remote_instance(
-        self, recv_req: SendWeightsToRemoteInstanceReqInput
-    ):
-        success, message = self.model_runner.send_weights_to_remote_instance(
-            recv_req.master_address,
-            recv_req.ports,
-            recv_req.group_name,
-        )
-        return success, message
-
-    def update_weights_from_distributed(
-        self, recv_req: UpdateWeightsFromDistributedReqInput
-    ):
-        success, message = self.model_runner.update_weights_from_distributed(
-            recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
-        )
-        return success, message
-
-    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
-        monkey_patch_torch_reductions()
-        success, message = self.model_runner.update_weights_from_tensor(
-            named_tensors=MultiprocessingSerializer.deserialize(
-                recv_req.serialized_named_tensors[self.tp_rank]
-            ),
-            load_format=recv_req.load_format,
-        )
-        return success, message
-
-    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
-        parameter = self.model_runner.get_weights_by_name(
-            recv_req.name, recv_req.truncate_size
-        )
-        return parameter
-
-    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
-        result = self.model_runner.load_lora_adapter(recv_req.to_ref())
-        return result
-
-    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
-        result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
-        return result
-
-    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
-        return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
```
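The refactor is a straightforward pull-up: shared delegation methods move into an abstract `BaseTpWorker` whose only contract is an abstract `model_runner` property, and `TpModelWorker` satisfies it by storing the runner privately and exposing it through that property. A minimal sketch of the shape (toy `Runner`, not sglang's `ModelRunner`):

```python
from abc import ABC, abstractmethod


class Runner:  # toy stand-in for ModelRunner
    tp_group = "tp-group-0"


class BaseWorker(ABC):
    @property
    @abstractmethod
    def model_runner(self) -> Runner:
        ...

    # Shared helpers are written once against the abstract property.
    def get_tp_group(self):
        return self.model_runner.tp_group


class Worker(BaseWorker):
    def __init__(self):
        self._model_runner = Runner()  # private storage...

    @property
    def model_runner(self) -> Runner:
        return self._model_runner  # ...exposed via the abstract property


assert Worker().get_tp_group() == "tp-group-0"
```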
python/sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -53,7 +53,6 @@ from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
     get_bool_env_var,
-    get_device_memory_capacity,
     is_hip,
     log_info_on_rank0,
     require_attn_tp_gather,
@@ -274,7 +273,6 @@ class CudaGraphRunner:
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
         self.seq_lens_cpu = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
```
python/sglang/srt/speculative/base_spec_worker.py (new file, mode 100644)

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from sglang.srt.managers.tp_worker import TpModelWorker


class BaseDraftWorker(ABC):
    @abstractmethod
    def draft():
        pass

    @abstractmethod
    def draft_extend():
        pass


class BaseSpecWorker(ABC):
    @property
    @abstractmethod
    def target_worker(self) -> TpModelWorker:
        pass

    @property
    @abstractmethod
    def draft_worker(self) -> BaseDraftWorker:
        pass
```
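As the names suggest, a concrete speculative worker pairs a target worker with a draft worker behind these two properties. A hedged sketch of what an implementation could look like (the class names and constructor wiring below are illustrative; the commit's actual implementation lives in eagle_worker_v2.py, whose diff is collapsed on this page):

```python
# Illustrative subclass wiring against the new abstraction; not the
# real eagle_worker_v2.py classes.
from sglang.srt.speculative.base_spec_worker import BaseDraftWorker, BaseSpecWorker


class MyDraftWorker(BaseDraftWorker):
    def draft(self):
        pass  # propose draft tokens for the next decode step

    def draft_extend(self):
        pass  # extend the draft model's state after verification


class MySpecWorker(BaseSpecWorker):
    def __init__(self, target_worker, draft_worker):
        self._target_worker = target_worker
        self._draft_worker = draft_worker

    @property
    def target_worker(self):
        return self._target_worker

    @property
    def draft_worker(self):
        return self._draft_worker
```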
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

```diff
@@ -40,7 +40,11 @@ class EAGLEDraftCudaGraphRunner:
     def __init__(self, eagle_worker: EAGLEWorker):
         # Parse args
         self.eagle_worker = eagle_worker
-        self.model_runner = model_runner = eagle_worker.model_runner
+        if not hasattr(eagle_worker, "model_runner"):
+            # V2: EagleDraftWorker
+            self.model_runner = model_runner = eagle_worker.draft_runner
+        else:
+            self.model_runner = model_runner = eagle_worker.model_runner
         self.graphs = {}
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
```
View file @
cde5a6e3
...
@@ -38,7 +38,12 @@ class EAGLEDraftExtendCudaGraphRunner:
...
@@ -38,7 +38,12 @@ class EAGLEDraftExtendCudaGraphRunner:
def
__init__
(
self
,
eagle_worker
:
EAGLEWorker
):
def
__init__
(
self
,
eagle_worker
:
EAGLEWorker
):
# Parse args
# Parse args
self
.
eagle_worker
=
eagle_worker
self
.
eagle_worker
=
eagle_worker
self
.
model_runner
=
model_runner
=
eagle_worker
.
model_runner
if
not
hasattr
(
eagle_worker
,
"model_runner"
):
# V2: EagleDraftWorker
self
.
model_runner
=
model_runner
=
eagle_worker
.
draft_runner
else
:
self
.
model_runner
=
model_runner
=
eagle_worker
.
model_runner
self
.
graphs
=
{}
self
.
graphs
=
{}
self
.
output_buffers
=
{}
self
.
output_buffers
=
{}
self
.
enable_torch_compile
=
model_runner
.
server_args
.
enable_torch_compile
self
.
enable_torch_compile
=
model_runner
.
server_args
.
enable_torch_compile
...
@@ -285,7 +290,7 @@ class EAGLEDraftExtendCudaGraphRunner:
...
@@ -285,7 +290,7 @@ class EAGLEDraftExtendCudaGraphRunner:
output_cache_loc_backup
=
forward_batch
.
out_cache_loc
output_cache_loc_backup
=
forward_batch
.
out_cache_loc
hidden_states_backup
=
forward_batch
.
spec_info
.
hidden_states
hidden_states_backup
=
forward_batch
.
spec_info
.
hidden_states
ret
=
self
.
eagle_worker
.
draft_
model_runner
.
model
.
forward
(
ret
=
self
.
model_runner
.
model
.
forward
(
forward_batch
.
input_ids
,
forward_batch
.
input_ids
,
forward_batch
.
positions
,
forward_batch
.
positions
,
forward_batch
,
forward_batch
,
...
...
python/sglang/srt/speculative/eagle_info.py
View file @
cde5a6e3
...
@@ -574,6 +574,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
...
@@ -574,6 +574,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
@
dataclass
@
dataclass
class
EagleDraftInput
(
SpecInput
,
EagleDraftInputV2Mixin
):
class
EagleDraftInput
(
SpecInput
,
EagleDraftInputV2Mixin
):
# Constant: alloc length per decode step
ALLOC_LEN_PER_DECODE
:
ClassVar
[
int
]
=
None
# The inputs for decode
# The inputs for decode
# shape: (b, topk)
# shape: (b, topk)
topk_p
:
torch
.
Tensor
=
None
topk_p
:
torch
.
Tensor
=
None
...
@@ -609,9 +612,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
...
@@ -609,9 +612,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
new_seq_lens
:
Optional
[
torch
.
Tensor
]
=
None
new_seq_lens
:
Optional
[
torch
.
Tensor
]
=
None
verify_done
:
Optional
[
torch
.
cuda
.
Event
]
=
None
verify_done
:
Optional
[
torch
.
cuda
.
Event
]
=
None
# FIXME(lsyin): remove this hack
ALLOC_LEN_PER_DECODE
:
ClassVar
[
int
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
super
().
__init__
(
SpecInputType
.
EAGLE_DRAFT
)
super
().
__init__
(
SpecInputType
.
EAGLE_DRAFT
)
...
...
python/sglang/srt/speculative/eagle_info_v2.py

```diff
@@ -9,7 +9,8 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
+from sglang.srt.mem_cache.common import alloc_token_slots
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -72,6 +73,34 @@ def assign_draft_cache_locs_page_size_1(
 @dataclass
 class EagleDraftInputV2Mixin:
+    def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
+        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
+
+        bs = batch.batch_size()
+
+        # TODO(lsyin): implement over-allocation
+        # Now seq_lens and allocate_lens are correct
+        batch.maybe_wait_verify_done()
+
+        new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
+        num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
+        out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
+        assign_req_to_token_pool[(bs,)](
+            batch.req_pool_indices,
+            batch.req_to_token_pool.req_to_token,
+            self.allocate_lens,
+            new_allocate_lens,
+            out_cache_loc,
+            batch.req_to_token_pool.req_to_token.shape[1],
+            next_power_of_2(bs),
+        )
+        self.allocate_lens = new_allocate_lens
+
+        # FIXME(lsyin): make this sync optional
+        batch.seq_lens_cpu = batch.seq_lens.cpu()
+        batch.seq_lens_sum = batch.seq_lens_cpu.sum().item()
+
     def prepare_for_v2_draft(
         self: EagleDraftInput,
         req_to_token_pool: ReqToTokenPool,
```
View file @
cde5a6e3
import
logging
import
logging
import
os
import
time
import
time
from
contextlib
import
contextmanager
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
from
huggingface_hub
import
snapshot_download
from
sglang.srt.distributed
import
(
from
sglang.srt.distributed
import
get_tp_group
GroupCoordinator
,
get_tp_group
,
patch_tensor_parallel_group
,
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.sampler
import
get_token_ids_logprobs
,
get_top_logprobs
from
sglang.srt.layers.sampler
import
get_token_ids_logprobs
,
get_top_logprobs
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
...
@@ -47,15 +40,17 @@ from sglang.srt.speculative.eagle_utils import (
...
@@ -47,15 +40,17 @@ from sglang.srt.speculative.eagle_utils import (
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
from
sglang.srt.speculative.spec_utils
import
(
from
sglang.srt.speculative.spec_utils
import
(
assign_draft_cache_locs
,
assign_draft_cache_locs
,
detect_nan
,
draft_tp_context
,
fast_topk
,
fast_topk
,
generate_token_bitmask
,
generate_token_bitmask
,
load_token_map
,
select_top_k_tokens
,
select_top_k_tokens
,
)
)
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
empty_context
,
empty_context
,
get_available_gpu_memory
,
get_available_gpu_memory
,
get_bool_env_var
,
get_bool_env_var
,
is_blackwell
,
is_cuda
,
is_cuda
,
next_power_of_2
,
next_power_of_2
,
)
)
...
@@ -67,14 +62,6 @@ logger = logging.getLogger(__name__)
...
@@ -67,14 +62,6 @@ logger = logging.getLogger(__name__)
SGLANG_RETURN_ORIGINAL_LOGPROB
=
get_bool_env_var
(
"SGLANG_RETURN_ORIGINAL_LOGPROB"
)
SGLANG_RETURN_ORIGINAL_LOGPROB
=
get_bool_env_var
(
"SGLANG_RETURN_ORIGINAL_LOGPROB"
)
@
contextmanager
def
draft_tp_context
(
tp_group
:
GroupCoordinator
):
# Draft model doesn't use dp and has its own tp group.
# We disable mscclpp now because it doesn't support 2 comm groups.
with
patch_tensor_parallel_group
(
tp_group
):
yield
class
EAGLEWorker
(
TpModelWorker
):
class
EAGLEWorker
(
TpModelWorker
):
def
__init__
(
def
__init__
(
...
@@ -100,7 +87,6 @@ class EAGLEWorker(TpModelWorker):
...
@@ -100,7 +87,6 @@ class EAGLEWorker(TpModelWorker):
self
.
speculative_algorithm
=
SpeculativeAlgorithm
.
from_string
(
self
.
speculative_algorithm
=
SpeculativeAlgorithm
.
from_string
(
server_args
.
speculative_algorithm
server_args
.
speculative_algorithm
)
)
self
.
padded_static_len
=
-
1
# Override the context length of the draft model to be the same as the target model.
# Override the context length of the draft model to be the same as the target model.
server_args
.
context_length
=
target_worker
.
model_runner
.
model_config
.
context_len
server_args
.
context_length
=
target_worker
.
model_runner
.
model_config
.
context_len
...
@@ -612,7 +598,8 @@ class EAGLEWorker(TpModelWorker):
...
@@ -612,7 +598,8 @@ class EAGLEWorker(TpModelWorker):
logits_output
,
_
=
self
.
draft_model_runner
.
forward
(
logits_output
,
_
=
self
.
draft_model_runner
.
forward
(
forward_batch
,
skip_attn_backend_init
=
True
forward_batch
,
skip_attn_backend_init
=
True
)
)
self
.
_detect_nan_if_needed
(
logits_output
)
if
self
.
server_args
.
enable_nan_detection
:
detect_nan
(
logits_output
)
probs
=
torch
.
softmax
(
logits_output
.
next_token_logits
,
dim
=-
1
)
probs
=
torch
.
softmax
(
logits_output
.
next_token_logits
,
dim
=-
1
)
topk_p
,
topk_index
=
fast_topk
(
probs
,
self
.
topk
,
dim
=-
1
)
topk_p
,
topk_index
=
fast_topk
(
probs
,
self
.
topk
,
dim
=-
1
)
if
self
.
hot_token_id
is
not
None
:
if
self
.
hot_token_id
is
not
None
:
...
@@ -680,7 +667,9 @@ class EAGLEWorker(TpModelWorker):
...
@@ -680,7 +667,9 @@ class EAGLEWorker(TpModelWorker):
# and will be applied to produce wrong results
# and will be applied to produce wrong results
batch
.
sampling_info
.
vocab_mask
=
None
batch
.
sampling_info
.
vocab_mask
=
None
self
.
_detect_nan_if_needed
(
logits_output
)
if
self
.
enable_nan_detection
:
detect_nan
(
logits_output
)
spec_info
.
hidden_states
=
logits_output
.
hidden_states
spec_info
.
hidden_states
=
logits_output
.
hidden_states
res
:
EagleVerifyOutput
=
spec_info
.
verify
(
res
:
EagleVerifyOutput
=
spec_info
.
verify
(
batch
,
batch
,
...
@@ -833,7 +822,8 @@ class EAGLEWorker(TpModelWorker):
...
@@ -833,7 +822,8 @@ class EAGLEWorker(TpModelWorker):
)
)
forward_batch
.
return_logprob
=
False
forward_batch
.
return_logprob
=
False
logits_output
,
_
=
self
.
draft_model_runner
.
forward
(
forward_batch
)
logits_output
,
_
=
self
.
draft_model_runner
.
forward
(
forward_batch
)
self
.
_detect_nan_if_needed
(
logits_output
)
if
self
.
enable_nan_detection
:
detect_nan
(
logits_output
)
assert
isinstance
(
forward_batch
.
spec_info
,
EagleDraftInput
)
assert
isinstance
(
forward_batch
.
spec_info
,
EagleDraftInput
)
assert
forward_batch
.
spec_info
is
batch
.
spec_info
assert
forward_batch
.
spec_info
is
batch
.
spec_info
self
.
capture_for_decode
(
logits_output
,
forward_batch
.
spec_info
)
self
.
capture_for_decode
(
logits_output
,
forward_batch
.
spec_info
)
...
@@ -928,7 +918,8 @@ class EAGLEWorker(TpModelWorker):
...
@@ -928,7 +918,8 @@ class EAGLEWorker(TpModelWorker):
)
)
self
.
capture_for_decode
(
logits_output
,
forward_batch
.
spec_info
)
self
.
capture_for_decode
(
logits_output
,
forward_batch
.
spec_info
)
self
.
_detect_nan_if_needed
(
logits_output
)
if
self
.
enable_nan_detection
:
detect_nan
(
logits_output
)
# Restore backup.
# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
...
@@ -948,24 +939,6 @@ class EAGLEWorker(TpModelWorker):
...
@@ -948,24 +939,6 @@ class EAGLEWorker(TpModelWorker):
draft_input
.
topk_p
,
draft_input
.
topk_index
=
fast_topk
(
probs
,
self
.
topk
,
dim
=-
1
)
draft_input
.
topk_p
,
draft_input
.
topk_index
=
fast_topk
(
probs
,
self
.
topk
,
dim
=-
1
)
draft_input
.
hidden_states
=
logits_output
.
hidden_states
draft_input
.
hidden_states
=
logits_output
.
hidden_states
def
_detect_nan_if_needed
(
self
,
logits_output
:
LogitsProcessorOutput
):
if
self
.
enable_nan_detection
:
logits
=
logits_output
.
next_token_logits
if
torch
.
any
(
torch
.
isnan
(
logits
)):
logger
.
error
(
"Detected errors during sampling! NaN in the logits."
)
raise
ValueError
(
"Detected errors during sampling! NaN in the logits."
)
def
load_token_map
(
token_map_path
:
str
)
->
List
[
int
]:
if
not
os
.
path
.
exists
(
token_map_path
):
cache_dir
=
snapshot_download
(
os
.
path
.
dirname
(
token_map_path
),
ignore_patterns
=
[
"*.bin"
,
"*.safetensors"
],
)
token_map_path
=
os
.
path
.
join
(
cache_dir
,
os
.
path
.
basename
(
token_map_path
))
hot_token_id
=
torch
.
load
(
token_map_path
,
weights_only
=
True
)
return
torch
.
tensor
(
hot_token_id
,
dtype
=
torch
.
int64
)
@
torch
.
compile
(
dynamic
=
True
)
@
torch
.
compile
(
dynamic
=
True
)
def
get_last_loc_large_page_size_top_k_1
(
def
get_last_loc_large_page_size_top_k_1
(
...
...
python/sglang/srt/speculative/eagle_worker_v2.py
View file @
cde5a6e3
This diff is collapsed.
Click to expand it.
python/sglang/srt/speculative/spec_utils.py

```diff
 from __future__ import annotations
 
 import logging
+import os
 import time
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, List
 
 import torch
 import triton
 import triton.language as tl
+from huggingface_hub import snapshot_download
 
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
+from sglang.srt.distributed.parallel_state import (
+    GroupCoordinator,
+    patch_tensor_parallel_group,
+)
 from sglang.srt.environ import envs
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.utils import is_cuda, is_hip
 
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
+    from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+    from sglang.srt.speculative.eagle_info import EagleVerifyInput
+
 if is_cuda():
     from sgl_kernel import fast_topk
 elif is_hip():
     from sgl_kernel import fast_topk
 
-if TYPE_CHECKING:
-    from sglang.srt.speculative.eagle_info import EagleVerifyInput
-
 logger = logging.getLogger(__name__)
@@ -603,3 +615,29 @@ def generate_token_bitmask(
     verify_input.grammar = grammar
     return allocate_token_bitmask
+
+
+def load_token_map(token_map_path: str) -> List[int]:
+    if not os.path.exists(token_map_path):
+        cache_dir = snapshot_download(
+            os.path.dirname(token_map_path),
+            ignore_patterns=["*.bin", "*.safetensors"],
+        )
+        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
+    hot_token_id = torch.load(token_map_path, weights_only=True)
+    return torch.tensor(hot_token_id, dtype=torch.int64)
+
+
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with patch_tensor_parallel_group(tp_group):
+        yield
+
+
+def detect_nan(logits_output: LogitsProcessorOutput):
+    logits = logits_output.next_token_logits
+    if torch.any(torch.isnan(logits)):
+        logger.error("Detected errors during sampling! NaN in the logits.")
+        raise ValueError("Detected errors during sampling! NaN in the logits.")
```
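`draft_tp_context` is a scoped-patch contextmanager: it swaps the process-wide "current TP group" to the draft model's group for the duration of a block and restores it afterwards. A toy re-implementation of that patching idea (not sglang's actual `patch_tensor_parallel_group`):

```python
from contextlib import contextmanager

# Minimal sketch of a scoped patch like draft_tp_context: swap a
# module-global "current group" for the duration of a with-block.
_CURRENT_TP_GROUP = "target-tp-group"


@contextmanager
def patch_tp_group(group):
    global _CURRENT_TP_GROUP
    saved, _CURRENT_TP_GROUP = _CURRENT_TP_GROUP, group
    try:
        yield
    finally:
        _CURRENT_TP_GROUP = saved  # always restored, even on error


with patch_tp_group("draft-tp-group"):
    assert _CURRENT_TP_GROUP == "draft-tp-group"  # draft forward runs here
assert _CURRENT_TP_GROUP == "target-tp-group"
```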
python/sglang/srt/speculative/standalone_worker.py

```diff
 import logging
-from contextlib import contextmanager
 from typing import Optional
 
 import torch
 
-from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
+from sglang.srt.speculative.eagle_worker import EAGLEWorker
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.speculative.spec_utils import draft_tp_context, load_token_map
 from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
 
 if is_cuda():
@@ -18,14 +17,6 @@ logger = logging.getLogger(__name__)
 SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
 
 
-@contextmanager
-def draft_tp_context(tp_group: GroupCoordinator):
-    # Draft model doesn't use dp and has its own tp group.
-    # We disable mscclpp now because it doesn't support 2 comm groups.
-    with patch_tensor_parallel_group(tp_group):
-        yield
-
-
 class StandaloneWorker(EAGLEWorker):
     def __init__(
@@ -51,7 +42,6 @@ class StandaloneWorker(EAGLEWorker):
         self.speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
        )
-        self.padded_static_len = -1
 
         # Override the context length of the draft model to be the same as the target model.
         server_args.context_length = target_worker.model_runner.model_config.context_len
```