Commit d3d4d767 (unverified)
Authored Mar 05, 2025 by Ying Sheng; committed by GitHub on Mar 05, 2025

[Eagle] Refactor eagle speculative decoding (#3986)

Co-authored-by: Ke Bao <ISPObaoke@163.com>

Parent: 5be8f1ed
Changes: 22. Showing 20 changed files with 529 additions and 349 deletions (+529 / -349).
Changed files:

  python/sglang/bench_one_batch.py                                  +2    -2
  python/sglang/srt/layers/attention/flashinfer_backend.py         +81   -58
  python/sglang/srt/layers/attention/triton_backend.py             +1    -0
  python/sglang/srt/managers/cache_controller.py                   +2    -2
  python/sglang/srt/managers/schedule_batch.py                     +26   -24
  python/sglang/srt/managers/schedule_policy.py                    +19   -14
  python/sglang/srt/managers/scheduler.py                          +31   -26
  python/sglang/srt/managers/tokenizer_manager.py                  +0    -17
  python/sglang/srt/managers/tp_worker.py                          +6    -1
  python/sglang/srt/managers/tp_worker_overlap_thread.py           +1    -1
  python/sglang/srt/mem_cache/chunk_cache.py                       +9    -5
  python/sglang/srt/mem_cache/hiradix_cache.py                     +4    -4
  python/sglang/srt/mem_cache/memory_pool.py                       +41   -18
  python/sglang/srt/mem_cache/radix_cache.py                       +11   -6
  python/sglang/srt/model_executor/model_runner.py                 +39   -8
  python/sglang/srt/server_args.py                                 +8    -3
  python/sglang/srt/speculative/build_eagle_tree.py                +2    -8
  python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py   +0    -1
  python/sglang/srt/speculative/eagle_utils.py                     +92   -58
  python/sglang/srt/speculative/eagle_worker.py                    +154  -93
python/sglang/bench_one_batch.py  (View file @ d3d4d767)

@@ -230,7 +230,7 @@ def extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
-        token_to_kv_pool=model_runner.token_to_kv_pool,
+        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
         tree_cache=None,
         model_config=model_runner.model_config,
         enable_overlap=False,

@@ -326,7 +326,7 @@ def latency_test_run_once(
     # Clear the pools.
     model_runner.req_to_token_pool.clear()
-    model_runner.token_to_kv_pool.clear()
+    model_runner.token_to_kv_pool_allocator.clear()

     measurement_results = {
         "run_name": run_name,
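The token_to_kv_pool to token_to_kv_pool_allocator rename above recurs through most files in this commit. For orientation, here is a minimal sketch of the allocator interface implied by the call sites visible in the diff (alloc, free, available_size, clear, free_group_begin/free_group_end); it is an illustration only, not the actual TokenToKVPoolAllocator from memory_pool.py, whose internals may differ.

import torch

# Illustrative stand-in for the allocator interface used by this commit's call
# sites; the real TokenToKVPoolAllocator in sglang.srt.mem_cache.memory_pool
# may differ in details.
class TokenToKVPoolAllocatorSketch:
    def __init__(self, size: int, device: str = "cpu"):
        self.size = size
        self.device = device
        self.free_slots = torch.arange(size, dtype=torch.int64, device=device)
        self.is_in_free_group = False
        self.free_group = []

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int):
        # Callers such as ScheduleBatch.alloc_token_slots expect None on failure.
        if need_size > len(self.free_slots):
            return None
        out = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return out

    def free(self, indices: torch.Tensor):
        # Inside a free group, defer the release until free_group_end().
        if self.is_in_free_group:
            self.free_group.append(indices)
        else:
            self.free_slots = torch.cat((self.free_slots, indices))

    def free_group_begin(self):
        self.is_in_free_group = True
        self.free_group = []

    def free_group_end(self):
        self.is_in_free_group = False
        if self.free_group:
            self.free(torch.cat(self.free_group))

    def clear(self):
        self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
        self.free_group = []
        self.is_in_free_group = False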
python/sglang/srt/layers/attention/flashinfer_backend.py  (View file @ d3d4d767)

@@ -20,14 +20,15 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo

 if is_flashinfer_available():
     from flashinfer import (

@@ -36,6 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
+    from flashinfer.decode import PosEncodingMode

 class WrapperDispatch(Enum):

@@ -113,6 +115,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 device=model_runner.device,
             )
             self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [

@@ -133,8 +136,11 @@ class FlashInferAttnBackend(AttentionBackend):
             assert self.num_wrappers == 1
             self.kv_last_page_len = kv_last_page_len_buf

-        self.qo_indptr = [
-            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
-            for _ in range(self.num_wrappers)
-        ]
+        if not self.skip_prefill:
+            self.qo_indptr = [
+                torch.zeros(
+                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                )
+                for _ in range(self.num_wrappers)
+            ]

@@ -276,7 +282,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []

@@ -346,7 +352,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(

@@ -526,7 +532,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

@@ -538,7 +544,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(

@@ -558,7 +564,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:

@@ -592,7 +598,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
         for wrapper_id in range(2):
             if wrapper_id == 0:

@@ -623,7 +629,7 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if spec_info is None:
             bs = len(req_pool_indices)

@@ -642,9 +648,9 @@ class FlashInferIndicesUpdaterDecode:
                 self.req_to_token.shape[1],
             )
         else:
+            assert isinstance(spec_info, EagleDraftInput)
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1

         wrapper.begin_forward(
             kv_indptr,
             kv_indices,

@@ -699,7 +705,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

@@ -713,7 +719,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens

@@ -746,7 +752,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:

@@ -787,7 +793,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:

@@ -829,10 +835,11 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        bs = len(req_pool_indices)
+        bs = len(seq_lens)
         if spec_info is None:
+            assert len(seq_lens) == len(req_pool_indices)
             # Normal extend
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]

@@ -855,10 +862,14 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
+            assert isinstance(spec_info, EagleDraftInput) or isinstance(
+                spec_info, EagleVerifyInput
+            )
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
                     paged_kernel_lens,
+                    paged_kernel_lens_sum,
                     self.req_to_token,
                 )
             )

@@ -890,6 +901,11 @@ class FlashInferIndicesUpdaterPrefill:
             )

+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global global_override_indptr_cpu

 class FlashInferMultiStepDraftBackend:
     """
     Wrap multiple flashinfer attention backends as one for multiple consecutive

@@ -907,6 +923,7 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (

@@ -929,7 +946,9 @@ class FlashInferMultiStepDraftBackend:
                     kv_last_page_len_buf=self.kv_last_page_len,
                 )
             )
         self.max_context_len = self.attn_backends[0].max_context_len
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

@@ -959,13 +978,23 @@ class FlashInferMultiStepDraftBackend:
             triton.next_power_of_2(bs),
         )
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
+        indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
+        global global_override_indptr_cpu
+
         for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
+            global_override_indptr_cpu = indptr_cpu_whole[i]
             call_fn(i, forward_batch)
+
+        global_override_indptr_cpu = None

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         kv_indices = torch.zeros(
             (

@@ -977,6 +1006,8 @@ class FlashInferMultiStepDraftBackend:
         )

         def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )

@@ -993,6 +1024,7 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]

@@ -1031,43 +1063,6 @@ class FlashInferMultiStepDraftBackend:
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    kv_indices_ptr,
-    req_to_token_ptr_stride: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = offset < kv_end - kv_start
-        data = tl.load(
-            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + kv_start + offset,
-            mask=mask,
-        )
-        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)

 def should_use_tensor_core(
     kv_cache_dtype: torch.dtype,
     num_attention_heads: int,

@@ -1089,6 +1084,21 @@ def should_use_tensor_core(
     if env_override is not None:
         return env_override.lower() == "true"

+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads

@@ -1118,12 +1128,18 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
-    **kwargs,
+    non_blocking: bool = True,
 ) -> None:
-    """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
+    """
+    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
+    Modifications:
+    - Remove unnecessary device-to-device copy for the cuda graph buffers.
+    - Remove unnecessary host-to-device copy for the metadata buffers.
+    """
     batch_size = len(last_page_len)
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
     if self.is_cuda_graph_enabled:
         if batch_size != self._fixed_batch_size:
             raise ValueError(

@@ -1136,13 +1152,19 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
+        # Skip these copies
+        # self._paged_kv_indptr_buf.copy_(indptr)
+        # self._paged_kv_indices_buf[: len(indices)] = indices
+        # self._paged_kv_last_page_len_buf.copy_(last_page_len)
     else:
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len

     # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
     if not q_data_type:
         q_data_type = data_type

     if not hasattr(self, "empty_q_data"):
         self.empty_q_data = torch.empty(
             0,

@@ -1159,6 +1181,7 @@ def fast_decode_plan(
             ),
         )
         self.last_page_len = torch.ones(32768, dtype=torch.int32)
     empty_q_data = self.empty_q_data
     empty_kv_cache = self.empty_kv_cache
     stream = torch.cuda.current_stream()
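A hedged sketch of the indptr fast path introduced above: common_template copies kv_indptr to the CPU once per batch and hands each draft step its slice through a module-level override, so the patched fast_decode_plan can skip a per-step device-to-host copy. The names below mirror the diff, but the surrounding pieces are simplified stand-ins, not the real backend classes.

import torch

# Module-level override consumed by a patched plan(); None means "no override".
global_override_indptr_cpu = None

def run_draft_steps(kv_indptr: torch.Tensor, bs: int, num_steps: int, call_fn):
    """Simplified analogue of FlashInferMultiStepDraftBackend.common_template."""
    global global_override_indptr_cpu
    # One device-to-host copy for all steps instead of one per plan() call.
    indptr_cpu_whole = kv_indptr[:, : bs + 1].cpu()
    for i in range(num_steps - 1):
        global_override_indptr_cpu = indptr_cpu_whole[i]
        call_fn(i)  # each step's plan() reads the pre-copied CPU indptr
    global_override_indptr_cpu = None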
python/sglang/srt/layers/attention/triton_backend.py  (View file @ d3d4d767)

@@ -156,6 +156,7 @@ class TritonAttnBackend(AttentionBackend):
                 spec_info.generate_attn_arg_prefill(
                     forward_batch.req_pool_indices,
                     forward_batch.seq_lens,
+                    None,
                     self.req_to_token,
                 )
             )
python/sglang/srt/managers/cache_controller.py  (View file @ d3d4d767)

@@ -22,7 +22,7 @@ from typing import List, Optional
 import torch

-from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost

 logger = logging.getLogger(__name__)

@@ -128,7 +128,7 @@ class HiCacheController:
     def __init__(
         self,
         mem_pool_device: MHATokenToKVPool,
-        mem_pool_host: MLATokenToKVPoolHost,
+        mem_pool_host: MHATokenToKVPoolHost,
         write_policy: str = "write_through_selective",
     ):
python/sglang/srt/managers/schedule_batch.py  (View file @ d3d4d767)

@@ -44,18 +44,16 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

 if TYPE_CHECKING:
-    from sglang.srt.server_args import ServerArgs
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access

@@ -523,7 +521,7 @@ class ScheduleBatch:
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool: BaseTokenToKVPool = None
+    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None

     # Batch configs

@@ -596,7 +594,7 @@ class ScheduleBatch:
         cls,
         reqs: List[Req],
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,

@@ -606,7 +604,7 @@ class ScheduleBatch:
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
-            token_to_kv_pool=token_to_kv_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
             tree_cache=tree_cache,
             model_config=model_config,
             enable_overlap=enable_overlap,

@@ -637,19 +635,19 @@ class ScheduleBatch:
         return req_pool_indices

     def alloc_token_slots(self, num_tokens: int):
-        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

         if out_cache_loc is None:
             if self.tree_cache is not None:
-                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
+                out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

             if out_cache_loc is None:
                 phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                 logger.error(
                     f"{phase_str} out of memory. Try to lower your batch size.\n"
                     f"Try to allocate {num_tokens} tokens.\n"
-                    f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+                    f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                 )
                 if self.tree_cache is not None:
                     self.tree_cache.pretty_print()

@@ -917,12 +915,12 @@ class ScheduleBatch:
     def check_decode_mem(self, buf_multiplier=1):
         bs = len(self.reqs) * buf_multiplier
-        if self.token_to_kv_pool.available_size() >= bs:
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True

-        self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+        self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)

-        if self.token_to_kv_pool.available_size() >= bs:
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True

         return False

@@ -945,6 +943,10 @@ class ScheduleBatch:
             reverse=True,
         )

+        retracted_reqs = []
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        first_iter = True
+
         def get_required_tokens(num_reqs: int):
             headroom_for_spec_decode = 0
             if server_args.speculative_algorithm:

@@ -958,18 +960,15 @@ class ScheduleBatch:
                 num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
             )

-        retracted_reqs = []
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
-        first_iter = True
         while (
-            self.token_to_kv_pool.available_size()
+            self.token_to_kv_pool_allocator.available_size()
             < get_required_tokens(len(sorted_indices))
             or first_iter
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
                 assert (
-                    self.token_to_kv_pool.available_size() > 0
+                    self.token_to_kv_pool_allocator.available_size() > 0
                 ), "No space left for only one request"
                 break

@@ -983,7 +982,7 @@ class ScheduleBatch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, : seq_lens_cpu[idx]
                 ]
-                self.token_to_kv_pool.free(token_indices)
+                self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
                 del self.tree_cache.entries[req.rid]
             else:

@@ -992,7 +991,7 @@ class ScheduleBatch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                 ]
-                self.token_to_kv_pool.free(token_indices)
+                self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)

                 # release the last node

@@ -1001,10 +1000,13 @@ class ScheduleBatch:
                 # NOTE(lsyin): we should use the newly evictable memory instantly.
                 residual_size = (
                     len(sorted_indices) * global_config.retract_decode_steps
-                    - self.token_to_kv_pool.available_size()
+                    - self.token_to_kv_pool_allocator.available_size()
                 )
                 residual_size = max(0, residual_size)
-                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
+                self.tree_cache.evict(residual_size, self.token_to_kv_pool_allocator.free)

             req.reset_for_retract()

         self.filter_batch(keep_indices=sorted_indices)

@@ -1183,7 +1185,7 @@ class ScheduleBatch:
         if self.spec_info:
             self.spec_info.merge_batch(other.spec_info)

-    def get_model_worker_batch(self):
+    def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:

@@ -1273,7 +1275,7 @@ class ModelWorkerBatch:
     req_pool_indices: torch.Tensor
     # The sequence length
     seq_lens: torch.Tensor
-    # The indices of output tokens in the token_to_kv_pool
+    # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor

     # The sum of all sequence lengths
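The retraction loop above keeps evicting and retracting requests until the allocator can cover get_required_tokens(len(sorted_indices)), that is num_reqs * retract_decode_steps plus any speculative-decoding headroom. A small worked sketch of that bound follows; the numbers are illustrative, not the project defaults.

def get_required_tokens(num_reqs: int, retract_decode_steps: int, headroom: int) -> int:
    # Mirror of the bound used in ScheduleBatch.retract_decode.
    return num_reqs * retract_decode_steps + headroom

# With 8 decode requests left, 20 reserved decode steps per request and no
# speculative headroom, retraction continues until 160 KV slots are free.
assert get_required_tokens(8, 20, 0) == 160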
python/sglang/srt/managers/schedule_policy.py  (View file @ d3d4d767)

@@ -22,9 +22,13 @@ from typing import Dict, List, Optional, Set, Union
 import torch

-from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.schedule_batch import (
+    Req,
+    ScheduleBatch,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.

@@ -75,7 +79,7 @@ class SchedulePolicy:
         # It is used to find the matching prefix for in-batch prefix caching.
         self.waiting_queue_radix_tree = RadixCache(
-            req_to_token_pool=None, token_to_kv_pool=None, disable=False
+            req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
         )

     def calc_priority(self, waiting_queue: List[Req]) -> bool:

@@ -251,7 +255,7 @@ class PrefillAdder:
     def __init__(
         self,
         tree_cache: BasePrefixCache,
-        token_to_kv_pool: BaseTokenToKVPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         running_batch: ScheduleBatch,
         new_token_ratio: float,
         rem_input_tokens: int,

@@ -259,7 +263,7 @@ class PrefillAdder:
         mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens

@@ -291,7 +295,7 @@ class PrefillAdder:
     @property
     def rem_total_tokens(self):
         return (
-            self.token_to_kv_pool.available_size()
+            self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
             - self.rem_total_token_offset
         )

@@ -299,7 +303,7 @@ class PrefillAdder:
     @property
     def cur_rem_tokens(self):
         return (
-            self.token_to_kv_pool.available_size()
+            self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
             - self.cur_rem_token_offset
         )

@@ -332,7 +336,6 @@ class PrefillAdder:
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
         self.can_run_list.append(req)
         self._prefill_one_req(
             0,
             req.extend_input_len,

@@ -400,8 +403,8 @@ class PrefillAdder:
             tokens_freed += tokens_occupied

         if (
-            self.rem_chunk_tokens is None
-            or req.extend_input_len <= self.rem_chunk_tokens
+            self.rem_chunk_tokens is None  # chunked prefill is disabled
+            or req.extend_input_len <= self.rem_chunk_tokens  # it is the last chunk
         ):
             # Non-chunked prefill
             self.can_run_list.append(req)

@@ -411,10 +414,11 @@ class PrefillAdder:
                 min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
             )
         else:
+            if self.rem_chunk_tokens == 0:
+                return AddReqResult.OTHER
+
             # Chunked prefill
             trunc_len = self.rem_chunk_tokens
-            if trunc_len == 0:
-                return AddReqResult.OTHER

             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[:trunc_len]

@@ -457,10 +461,11 @@ class PrefillAdder:
                 ),
             )
         else:
+            if self.rem_chunk_tokens == 0:
+                return AddReqResult.OTHER
+
             # Chunked prefill
             trunc_len = self.rem_chunk_tokens
-            if trunc_len == 0:
-                return AddReqResult.OTHER

             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
python/sglang/srt/managers/scheduler.py  (View file @ d3d4d767)

@@ -164,7 +164,7 @@ class Scheduler:
                 self.server_args.speculative_num_draft_tokens
                 + (
                     self.server_args.speculative_eagle_topk
-                    * self.server_args.speculative_num_steps
+                    * self.server_args.speculative_num_draft_tokens
                 )
             )
             if not self.spec_algorithm.is_none()

@@ -309,7 +309,9 @@ class Scheduler:
         )

         # Init memory pool and cache
-        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            self.tp_worker.get_memory_pool()
+        )

         if (
             server_args.chunked_prefill_size is not None

@@ -317,18 +319,18 @@ class Scheduler:
         ):
             self.tree_cache = ChunkCache(
                 req_to_token_pool=self.req_to_token_pool,
-                token_to_kv_pool=self.token_to_kv_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
             )
         else:
             if self.enable_hierarchical_cache:
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool=self.token_to_kv_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 )
             else:
                 self.tree_cache = RadixCache(
                     req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool=self.token_to_kv_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                     disable=server_args.disable_radix_cache,
                 )

@@ -458,7 +460,6 @@ class Scheduler:
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
-                (SetInternalStateReq, self.set_internal_state),
             ]
         )

@@ -809,7 +810,8 @@ class Scheduler:
         running_bs: int,
     ):
         num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+            self.token_to_kv_pool_allocator.available_size()
+            + self.tree_cache.evictable_size()
         )
         self._largest_prefill_len = max(
             self._largest_prefill_len, adder.log_input_tokens

@@ -844,7 +846,8 @@ class Scheduler:
         self.num_generated_tokens = 0
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
         num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+            self.token_to_kv_pool_allocator.available_size()
+            + self.tree_cache.evictable_size()
         )

         if RECORD_STEP_TIME:

@@ -894,7 +897,8 @@ class Scheduler:
     def check_memory(self):
         available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+            self.token_to_kv_pool_allocator.available_size()
+            + self.tree_cache.evictable_size()
         )
         protected_size = self.tree_cache.protected_size()
         memory_leak = available_size != (

@@ -999,7 +1003,7 @@ class Scheduler:
         # Prefill policy
         adder = PrefillAdder(
             self.tree_cache,
-            self.token_to_kv_pool,
+            self.token_to_kv_pool_allocator,
             self.running_batch,
             self.new_token_ratio,
             self.max_prefill_tokens,

@@ -1099,7 +1103,7 @@ class Scheduler:
         new_batch = ScheduleBatch.init_new(
             can_run_list,
             self.req_to_token_pool,
-            self.token_to_kv_pool,
+            self.token_to_kv_pool_allocator,
             self.tree_cache,
             self.model_config,
             self.enable_overlap,

@@ -1143,8 +1147,6 @@ class Scheduler:
             retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
             self.new_token_ratio = new_token_ratio
-            if self.draft_worker:
-                self.draft_worker.finish_request(retracted_reqs)

             logger.info(
                 "Decode out of memory happened. "

@@ -1184,11 +1186,12 @@ class Scheduler:
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
+                bid = model_worker_batch.bid
             else:
                 (
                     logits_output,
                     next_token_ids,
-                    model_worker_batch,
+                    bid,
                     num_accepted_tokens,
                 ) = self.draft_worker.forward_batch_speculative_generation(batch)
                 self.spec_num_total_accepted_tokens += (

@@ -1214,7 +1217,7 @@ class Scheduler:
                 next_token_ids=next_token_ids,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-                bid=model_worker_batch.bid,
+                bid=bid,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()

@@ -1230,6 +1233,7 @@ class Scheduler:
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
     ):
         if batch.forward_mode.is_decode():
+            assert isinstance(result, GenerationBatchResult)
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
                 self.running_batch = None

@@ -1302,7 +1306,7 @@ class Scheduler:
             if self.is_mixed_chunk and self.enable_overlap and req.finished():
                 # Free the one delayed token for the mixed decode batch
                 j = len(batch.out_cache_loc) - len(batch.reqs) + i
-                self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
+                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
                 continue

             if req.is_chunked <= 0:

@@ -1420,23 +1424,27 @@ class Scheduler:
         self.num_generated_tokens += len(batch.reqs)

         if self.enable_overlap:
+            assert batch.spec_algorithm.is_none()
             logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             next_token_logprobs = logits_output.next_token_logprobs
-        else:
+        elif batch.spec_algorithm.is_none():
+            # spec decoding handles output logprobs inside verify process.
             next_token_ids = next_token_ids.tolist()
             if batch.return_logprob:
                 next_token_logprobs = logits_output.next_token_logprobs.tolist()

-        self.token_to_kv_pool.free_group_begin()
+        self.token_to_kv_pool_allocator.free_group_begin()

         # Check finish condition
+        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
+        # We should ignore using next_token_ids for spec decoding cases.
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if req.is_retracted:
                 continue

             if self.enable_overlap and req.finished():
                 # Free the one delayed token
-                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
+                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
                 continue

             if batch.spec_algorithm.is_none():

@@ -1479,7 +1487,7 @@ class Scheduler:
             batch.next_batch_sampling_info.sampling_info_done.set()

         self.stream_output(batch.reqs, batch.return_logprob)

-        self.token_to_kv_pool.free_group_end()
+        self.token_to_kv_pool_allocator.free_group_end()

         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
         if (

@@ -1718,9 +1726,6 @@ class Scheduler:
                 and not self.model_config.is_multimodal_gen
             )
         ):
-            if self.draft_worker and req.finished():
-                self.draft_worker.finish_request(req)
-
             rids.append(req.rid)
             finished_reasons.append(
                 req.finished_reason.to_json() if req.finished_reason else None

@@ -1860,7 +1865,7 @@ class Scheduler:
         idle_batch = ScheduleBatch.init_new(
             [],
             self.req_to_token_pool,
-            self.token_to_kv_pool,
+            self.token_to_kv_pool_allocator,
             self.tree_cache,
             self.model_config,
             self.enable_overlap,

@@ -1916,11 +1921,11 @@ class Scheduler:
         if self.grammar_backend:
             self.grammar_backend.reset()
         self.req_to_token_pool.clear()
-        self.token_to_kv_pool.clear()
+        self.token_to_kv_pool_allocator.clear()

         if not self.spec_algorithm.is_none():
             self.draft_worker.model_runner.req_to_token_pool.clear()
-            self.draft_worker.model_runner.token_to_kv_pool.clear()
+            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

         self.num_generated_tokens = 0
         self.forward_ct_decode = 0
python/sglang/srt/managers/tokenizer_manager.py  (View file @ d3d4d767)

@@ -82,8 +82,6 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,

@@ -257,9 +255,6 @@ class TokenizerManager:
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )

         self._result_dispatcher = TypeBasedDispatcher(
             [

@@ -309,10 +304,6 @@ class TokenizerManager:
                 (
                     GetInternalStateReqOutput,
                     self.get_internal_state_communicator.handle_recv,
                 ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )

@@ -774,14 +765,6 @@ class TokenizerManager:
         )
         return res[0].internal_state

-    async def set_internal_state(
-        self, obj: SetInternalStateReq
-    ) -> SetInternalStateReqOutput:
-        res: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return res[0]
-
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
python/sglang/srt/managers/tp_worker.py  View file @ d3d4d767
...
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
...
@@ -49,6 +50,8 @@ class TpModelWorker:
         dp_rank: Optional[int],
         nccl_port: int,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.tp_rank = tp_rank
...
@@ -77,6 +80,8 @@ class TpModelWorker:
             nccl_port=nccl_port,
             server_args=server_args,
             is_draft_worker=is_draft_worker,
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
         )
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
...
@@ -154,7 +159,7 @@ class TpModelWorker:
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
-            self.model_runner.token_to_kv_pool,
+            self.model_runner.token_to_kv_pool_allocator,
         )

     def forward_batch_generation(
...
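The new `req_to_token_pool` / `token_to_kv_pool_allocator` arguments and the updated `get_memory_pool` return value are what let a draft worker reuse the target worker's pools instead of building its own. A small self-contained sketch of that contract follows; the `_Stub*` classes are placeholders, only the attribute and method names mirror the diff.

from typing import Tuple


class _StubPool:
    """Placeholder for ReqToTokenPool / TokenToKVPoolAllocator."""
    pass


class _StubRunner:
    def __init__(self):
        self.req_to_token_pool = _StubPool()
        self.token_to_kv_pool_allocator = _StubPool()


class _StubWorker:
    """Mirrors the updated get_memory_pool() contract from the hunk above."""

    def __init__(self):
        self.model_runner = _StubRunner()

    def get_memory_pool(self) -> Tuple[_StubPool, _StubPool]:
        return (
            self.model_runner.req_to_token_pool,
            self.model_runner.token_to_kv_pool_allocator,
        )


# A draft worker unpacks both objects and passes them back into its own
# TpModelWorker/ModelRunner constructor (see eagle_worker.py later in this diff).
req_pool, kv_allocator = _StubWorker().get_memory_pool()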
python/sglang/srt/managers/tp_worker_overlap_thread.py  View file @ d3d4d767
...
@@ -100,7 +100,7 @@ class TpModelWorkerClient:
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
-            self.worker.model_runner.token_to_kv_pool,
+            self.worker.model_runner.token_to_kv_pool_allocator,
         )

     def forward_thread_func(self):
...
python/sglang/srt/mem_cache/chunk_cache.py  View file @ d3d4d767
 from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""

 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

 import torch

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
...
@@ -21,11 +20,13 @@ class ChunkCacheEntry:
 class ChunkCache(BasePrefixCache):
     def __init__(
-        self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.entries: Dict[str, ChunkCacheEntry] = {}

         self.reset()
...
@@ -51,7 +52,7 @@ class ChunkCache(BasePrefixCache):
             req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
-        self.token_to_kv_pool.free(kv_indices)
+        self.token_to_kv_pool_allocator.free(kv_indices)

         if req.rid in self.entries:
             del self.entries[req.rid]
...
@@ -91,3 +92,6 @@ class ChunkCache(BasePrefixCache):
     def protected_size(self):
         return 0
+
+    def pretty_print(self):
+        return ""
python/sglang/srt/mem_cache/hiradix_cache.py  View file @ d3d4d767
...
@@ -7,8 +7,8 @@ import torch
 from sglang.srt.managers.cache_controller import HiCacheController
 from sglang.srt.mem_cache.memory_pool import (
-    BaseTokenToKVPool,
-    MLATokenToKVPoolHost,
+    MHATokenToKVPool,
+    MHATokenToKVPoolHost,
     ReqToTokenPool,
 )
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
...
@@ -21,9 +21,9 @@ class HiRadixCache(RadixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: BaseTokenToKVPool,
+        token_to_kv_pool: MHATokenToKVPool,
     ):
-        self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool)
+        self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
         self.cache_controller = HiCacheController(
             token_to_kv_pool, self.token_to_kv_pool_host
         )
...
python/sglang/srt/mem_cache/memory_pool.py  View file @ d3d4d767
...
@@ -20,9 +20,12 @@ Memory pool.
 SGLang has two levels of memory pool.
 ReqToTokenPool maps a a request to its token locations.
-BaseTokenToKVPool maps a token location to its KV cache data.
+TokenToKVPoolAllocator maps a token location to its KV cache data.
+KVCache actually holds the physical kv cache. Allocation indices are allocated
+by TokenToKVPoolAllocator
 """

+import abc
 import logging
 import threading
 from enum import IntEnum
...
@@ -89,7 +92,7 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))


-class BaseTokenToKVPool:
+class TokenToKVPoolAllocator:
     """A memory pool that maps a token location to its kv cache data."""

     def __init__(
...
@@ -100,11 +103,6 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
         self.device = device

         self.free_slots = None
...
@@ -148,15 +146,22 @@ class BaseTokenToKVPool:
         self.is_in_free_group = False
         self.free_group = []


+class KVCache(abc.ABC):
+
+    @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

+    @abc.abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

+    @abc.abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()

+    @abc.abstractmethod
     def set_kv_buffer(
         self,
         layer: RadixAttention,
...
@@ -167,7 +172,7 @@ class BaseTokenToKVPool:
         raise NotImplementedError()


-class MHATokenToKVPool(BaseTokenToKVPool):
+class MHATokenToKVPool(KVCache):

     def __init__(
         self,
...
@@ -179,8 +184,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         device: str,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype

         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
...
@@ -297,7 +308,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)


-class MLATokenToKVPool(BaseTokenToKVPool):
+class MLATokenToKVPool(KVCache):

     def __init__(
         self,
         size: int,
...
@@ -308,8 +319,14 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         device: str,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype

         self.kv_lora_rank = kv_lora_rank
         memory_saver_adapter = TorchMemorySaverAdapter.create(
...
@@ -356,7 +373,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
             self.kv_buffer[layer_id][loc] = cache_k


-class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
+class DoubleSparseTokenToKVPool(KVCache):

     def __init__(
         self,
         size: int,
...
@@ -368,8 +385,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
         heavy_channel_num: int,
         enable_memory_saver: bool,
     ):
-        super().__init__(size, dtype, device)
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype

         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
...
@@ -437,12 +460,12 @@ def synchronized(func):
     return wrapper


-class MLATokenToKVPoolHost:
+class MHATokenToKVPoolHost:

     def __init__(
         self,
         device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float = 4.0,
+        host_to_device_ratio: float = 2.0,
         pin_memory: bool = False,  # no need to use pin memory with the double buffering
         device: str = "cpu",
     ):
...
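The core change in this file is splitting the old `BaseTokenToKVPool` into two roles: `TokenToKVPoolAllocator` hands out and reclaims token slot indices, while the `KVCache` subclasses (`MHATokenToKVPool`, `MLATokenToKVPool`, `DoubleSparseTokenToKVPool`) own the physical key/value buffers. The toy free-list below sketches only the allocator role, using a plain Python list where the real class uses tensors; the class and method names here are illustrative.

from typing import List, Optional


class MiniAllocator:
    """Toy index allocator: no KV data lives here, only free slot indices."""

    def __init__(self, size: int):
        self.size = size
        self.free_slots: List[int] = list(range(1, size + 1))

    def alloc(self, need: int) -> Optional[List[int]]:
        # Return `need` slot indices, or None when the pool is exhausted.
        if need > len(self.free_slots):
            return None
        out, self.free_slots = self.free_slots[:need], self.free_slots[need:]
        return out

    def free(self, indices: List[int]) -> None:
        # Freed indices become reusable; the KV buffers themselves are untouched.
        self.free_slots.extend(indices)

    def clear(self) -> None:
        self.free_slots = list(range(1, self.size + 1))


alloc = MiniAllocator(size=8)
slots = alloc.alloc(3)   # e.g. [1, 2, 3]
alloc.free(slots)        # the same indices can be handed out again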
python/sglang/srt/mem_cache/radix_cache.py  View file @ d3d4d767
...
@@ -26,8 +26,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 import torch

+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
...
@@ -79,11 +80,11 @@ class RadixCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: BaseTokenToKVPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         disable: bool = False,
     ):
         self.req_to_token_pool = req_to_token_pool
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.disable = disable
         self.reset()
...
@@ -139,7 +140,7 @@ class RadixCache(BasePrefixCache):
             kv_indices = self.req_to_token_pool.req_to_token[
                 req.req_pool_idx, :token_ids_len
             ]
-            self.token_to_kv_pool.free(kv_indices)
+            self.token_to_kv_pool_allocator.free(kv_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
             return
...
@@ -151,7 +152,9 @@ class RadixCache(BasePrefixCache):
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
-        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+        self.token_to_kv_pool_allocator.free(
+            kv_indices[len(req.prefix_indices) : new_prefix_len]
+        )

         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
...
@@ -171,7 +174,9 @@ class RadixCache(BasePrefixCache):
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
-        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+        self.token_to_kv_pool_allocator.free(
+            kv_indices[len(req.prefix_indices) : new_prefix_len]
+        )

         # The prefix indices could be updated, reuse it
         new_indices, new_last_node = self.match_prefix(token_ids)
...
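Both `free` call sites above release only a slice of `kv_indices`: the part from `len(req.prefix_indices)` up to the `new_prefix_len` returned by `insert`. A tiny numeric illustration of which indices end up in that slice (the numbers are made up):

# Illustrative numbers only.
prefix_indices = [10, 11, 12]            # slots matched before this request ran
kv_indices = [10, 11, 12, 13, 14, 15]    # all slots the request used
new_prefix_len = 5                       # prefix length now covered by the tree

# This is the slice handed to token_to_kv_pool_allocator.free() in the hunks above.
to_free = kv_indices[len(prefix_indices):new_prefix_len]
assert to_free == [13, 14]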
python/sglang/srt/model_executor/model_runner.py  View file @ d3d4d767
...
@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
+    TokenToKVPoolAllocator,
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
...
@@ -98,6 +99,8 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
...
@@ -115,6 +118,8 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

         # Model-specific adjustment
         if (
...
@@ -257,8 +262,8 @@ class ModelRunner:
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")

         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
...
@@ -660,12 +665,25 @@ class ModelRunner:
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                max_num_reqs = self.server_args.max_num_reqs
             else:
+                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
+                # can be concurrently allocated, so we should give a headroom for it.
                 self.server_args.draft_runner_cache_size = (
                     self.max_total_num_tokens
+                    + max_num_reqs * self.server_args.speculative_num_steps  # draft
+                    + max_num_reqs * self.server_args.speculative_eagle_topk  # verify
+                    + max_num_reqs * self.server_args.speculative_num_draft_tokens  # buffer
                     + 100
                 )
+                # Target worker and draft worker shares the same indices for the
+                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
+                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                self.server_args.max_num_reqs = max_num_reqs

         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
...
@@ -681,12 +699,25 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )

+        if self.req_to_token_pool is None:
             self.req_to_token_pool = ReqToTokenPool(
                 size=max_num_reqs + 1,
                 max_context_len=self.model_config.context_len + 4,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
             )
+        else:
+            # Draft worker shares req_to_token_pool with the target worker.
+            assert self.is_draft_worker
+
+        if self.token_to_kv_pool_allocator is None:
+            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        else:
+            assert self.is_draft_worker

         if (
             self.model_config.attention_arch == AttentionArch.MLA
...
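With the two optional constructor arguments, pool creation becomes conditional: a normal (target) runner still builds its own `ReqToTokenPool` and `TokenToKVPoolAllocator`, while a draft runner must be handed both and merely asserts it is a draft worker. The condensed sketch below mirrors that branching; the stub classes and the `init_pools` helper are illustrative, not part of the commit.

from typing import Optional


class _ReqPool:
    pass


class _KVAllocator:
    pass


def init_pools(
    is_draft_worker: bool,
    req_to_token_pool: Optional[_ReqPool],
    token_to_kv_pool_allocator: Optional[_KVAllocator],
):
    # Mirrors the if/else added above: create the pool when absent, otherwise
    # require that the caller is the draft worker reusing shared pools.
    if req_to_token_pool is None:
        req_to_token_pool = _ReqPool()
    else:
        assert is_draft_worker

    if token_to_kv_pool_allocator is None:
        token_to_kv_pool_allocator = _KVAllocator()
    else:
        assert is_draft_worker

    return req_to_token_pool, token_to_kv_pool_allocator


# Target worker: builds both. Draft worker: passes the target's objects in.
init_pools(False, None, None)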
python/sglang/srt/server_args.py  View file @ d3d4d767
...
@@ -280,11 +280,16 @@ class ServerArgs:
             self.disable_overlap_schedule = True
             self.prefill_only_one_req = True
             self.disable_cuda_graph_padding = True
-            self.disable_radix_cache = True
-            self.chunked_prefill_size = -1
+            if self.max_running_requests is None:
+                self.max_running_requests = 32
             logger.info(
-                f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
+                "Overlap scheduler are disabled because of using "
+                "eagle speculative decoding."
+                "Max running request set to 32 because of using eagle speculative decoding."
             )
+            # The token generated from the verify step is counted.
+            # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
+            assert self.speculative_num_steps < self.speculative_num_draft_tokens

         # GGUF
         if (
...
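The new assertion encodes a simple consistency rule: `speculative_num_draft_tokens` already counts the token produced by the verify step, so running at least that many draft steps would only produce tokens that are guaranteed to be discarded. A two-line check in the same spirit, with example values rather than the project defaults:

# Example values, not defaults from the commit.
speculative_num_steps = 5
speculative_num_draft_tokens = 8

# Same invariant as the new assert in ServerArgs.
assert speculative_num_steps < speculative_num_draft_tokens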
python/sglang/srt/speculative/build_eagle_tree.py  View file @ d3d4d767
...
@@ -3,14 +3,8 @@
 from typing import List

 import torch

-from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
-from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient
+from sglang.srt.utils import is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
+    from sgl_kernel import (
+        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
+    )


 def build_tree_kernel_efficient_preprocess(
...
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py  View file @ d3d4d767
...
@@ -21,7 +21,6 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.speculative.eagle_utils import EagleDraftInput

 if TYPE_CHECKING:
-    from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
...
python/sglang/srt/speculative/eagle_utils.py  View file @ d3d4d767
 from __future__ import annotations

-import dataclasses
-from typing import TYPE_CHECKING, List
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List

 import torch
 import torch.nn.functional as F
 import triton
 import triton.language as tl

-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import (
     build_tree_kernel,
...
@@ -25,7 +26,7 @@ if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch


-@dataclasses.dataclass
+@dataclass
 class EagleDraftInput:
     # The inputs for decode
     # shape: (b, topk)
...
@@ -46,57 +47,46 @@ class EagleDraftInput:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None

+    # indices of unfinished requests during extend-after-decode
+    # e.g. [0, 2, 3, 4] if only the 1st request is finished
+    keep_indices: List[int] = None
+
     def prepare_for_extend(self, batch: ScheduleBatch):
-        req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
-        out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
-        batch.out_cache_loc = out_cache_loc
+        assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
+        # Prefill only generate 1 token.
+        assert len(self.verified_id) == len(batch.seq_lens)

         pt = 0
-        for i, req in enumerate(batch.reqs):
-            req.req_pool_idx = req_pool_indices[i]
-            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
-            assert seq_len - pre_len == req.extend_input_len
-            if pre_len > 0:
-                batch.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    :pre_len
-                ] = req.prefix_indices
-            batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
-            )
-            pt += req.extend_input_len
-
-        # TODO: support batching inputs
-        assert len(batch.extend_lens) == 1
-        batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
+        for i, extend_len in enumerate(batch.extend_lens):
+            input_ids = batch.input_ids[pt : pt + extend_len]
+            batch.input_ids[pt : pt + extend_len] = torch.concat(
+                (input_ids[1:], self.verified_id[i].reshape(1))
+            )
+            pt += extend_len

     def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
-        batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
+        assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
         accept_length_cpu = batch.spec_info.accept_length_cpu
         batch.extend_lens = [x + 1 for x in accept_length_cpu]
+        batch.extend_num_tokens = sum(batch.extend_lens)
         batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
         seq_lens_cpu = batch.seq_lens.tolist()
         assert len(batch.req_pool_indices) == len(batch.reqs)

         pt = 0
         i = 0
-        for req in batch.reqs:
+        self.keep_indices = []
+        for idx, req in enumerate(batch.reqs):
             if req.finished():
                 continue
+            self.keep_indices.append(idx)
             # assert seq_len - pre_len == req.extend_input_len
             input_len = batch.extend_lens[i]
             seq_len = seq_lens_cpu[i]
             batch.req_to_token_pool.req_to_token[req.req_pool_idx][
                 seq_len - input_len : seq_len
             ] = batch.out_cache_loc[pt : pt + input_len]
             pt += input_len
             i += 1
+        assert pt == batch.out_cache_loc.shape[0]

-        self.positions = torch.empty_like(self.verified_id)
-        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
+        self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
+        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
         self.accept_length.add_(1)
         create_extend_spec_info[(self.accept_length.numel(),)](
...
@@ -117,14 +107,22 @@ class EagleDraftInput:
         self,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
         bs = self.accept_length.numel()
+        keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
+        req_pool_indices = req_pool_indices[keep_indices]
+        assert req_pool_indices.shape[0] == bs
+        assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]
+
         qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
         cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        # TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
         kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
         create_flashinfer_kv_indices_triton[(bs,)](
...
@@ -162,7 +160,21 @@ class EagleDraftInput:
         self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])


-@dataclasses.dataclass
+@dataclass
+class EagleVerifyOutput:
+    # Draft input batch
+    draft_input: EagleDraftInput
+    # Logit outputs from target worker
+    logits_output: LogitsProcessorOutput
+    # Accepeted token ids including the bonus token
+    verified_id: torch.Tensor
+    # Accepeted token length per sequence in a batch in CPU.
+    accept_length_per_req_cpu: List[int]
+    # Accepeted indices from logits_output.next_token_logits
+    accepeted_indices_cpu: List[int]
+
+
+@dataclass
 class EagleVerifyInput:
     draft_token: torch.Tensor
     custom_mask: torch.Tensor
...
@@ -267,6 +279,7 @@ class EagleVerifyInput:
         self,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
         req_to_token: torch.Tensor,
     ):
         batch_size = len(req_pool_indices)
...
@@ -285,7 +298,11 @@ class EagleVerifyInput:
         paged_kernel_lens = paged_kernel_lens + self.draft_token_num
         cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

-        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum + self.draft_token_num * batch_size,
+            dtype=torch.int32,
+            device="cuda",
+        )
         create_flashinfer_kv_indices_triton[(batch_size,)](
             req_to_token,
...
@@ -298,7 +315,21 @@ class EagleVerifyInput:
         )
         return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask

-    def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
+    def verify(
+        self,
+        batch: ScheduleBatch,
+        logits_output: torch.Tensor,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+    ) -> torch.Tensor:
+        """WARNING: This API in-place modifies the states of logits_output
+
+        Verify and find accepted tokens based on logits output and batch
+        (which contains spec decoding information).
+
+        This API updates values inside logits_output based on the accepted
+        tokens. I.e., logits_output.next_token_logits only contains
+        accepeted token logits.
+        """
         draft_token = torch.cat(
             [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
             dim=-1,
...
@@ -367,7 +398,6 @@ class EagleVerifyInput:
         new_accept_index = []
         unfinished_index = []
-        finished_extend_len = {}  # {rid:accept_length + 1}
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False
...
@@ -382,7 +412,6 @@ class EagleVerifyInput:
                     id = predict_cpu[idx]
                     # if not found_finished:
                     req.output_ids.append(id)
-                    finished_extend_len[req.rid] = j + 1
                     req.check_finished()
                     if req.finished():
                         has_finished = True
...
@@ -400,11 +429,10 @@ class EagleVerifyInput:
             accept_index = accept_index[accept_index != -1]

         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]

         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False
         mem_need_free_idx = batch.out_cache_loc[evict_mask]
-        batch.token_to_kv_pool.free(mem_need_free_idx)
+        token_to_kv_pool_allocator.free(mem_need_free_idx)
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
...
@@ -427,20 +455,16 @@ class EagleVerifyInput:
         ]
         if has_finished:
             draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
+            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
+                unfinished_index
+            ]
         else:
             draft_input.seq_lens_for_draft_extend = batch.seq_lens
+            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices

-        batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
-        logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
-        return (
-            draft_input,
-            logits_output,
-            verified_id,
-            finished_extend_len,
-            accept_length_cpu,
-        )
+        return EagleVerifyOutput(
+            draft_input=draft_input,
+            logits_output=logits_output,
+            verified_id=verified_id,
+            accept_length_per_req_cpu=accept_length_cpu,
+            accepeted_indices_cpu=accept_index,
+        )
...
@@ -456,6 +480,18 @@ def eagle_verify_retrive(
     draft_token_num: tl.constexpr,
     max_len_upper: tl.constexpr,
 ):
+    """
+    Args:
+        retrive_index: Pointer to indices of draft tokens
+        accept_mask: Mask indicating which tokens were accepted
+        retrive_cum_len: Cumulative lengths of token sequences in a batch
+        accept_index (out): Accept token indices
+        accept_length (out): Length of accepted tokens per sequence in a batch
+        extract_index (out): Index for last accepted tokens
+        max_len: Maximum length in a batch
+        draft_token_num: Number of tokens speculatively generated
+        max_len_upper An upper bound for token sequence length
+    """
     pid = tl.program_id(axis=0)
     retrive_end = tl.load(retrive_cum_len + pid + 1)
...
@@ -649,7 +685,7 @@ def generate_draft_decode_kv_indices(
     tl.store(kv_indptr + zid, base + zid * iters)


-@torch.compile
+@torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
     topk_p: torch.Tensor,
...
@@ -671,13 +707,11 @@ def select_top_k_tokens(
             .unsqueeze(0)
             .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
         )
     else:
         # The later decode steps
         expand_scores = torch.mul(
             scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
         )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)

         topk_cs_p, topk_cs_index = fast_topk(
             expand_scores.flatten(start_dim=1), topk, dim=-1
         )  # (b, topk)
...
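`verify` now returns a structured `EagleVerifyOutput` rather than a positional tuple, which is what allows the caller (eagle_worker.py, next file) to index the target logits by `accepeted_indices_cpu` after verification. The stand-in dataclass below lets the consumption pattern run on its own; the field names follow the diff, everything else (shapes, values) is illustrative.

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class _VerifyOutputLike:
    # Same field names as EagleVerifyOutput in the hunk above.
    verified_id: torch.Tensor
    accept_length_per_req_cpu: List[int]
    accepeted_indices_cpu: List[int]


# Pretend the target model scored 6 candidate positions and 4 were accepted.
next_token_logits = torch.randn(6, 50)
res = _VerifyOutputLike(
    verified_id=torch.tensor([11, 12, 13, 14]),
    accept_length_per_req_cpu=[4],
    accepeted_indices_cpu=[0, 1, 2, 5],
)

# Keep only the logits rows of accepted positions, as the refactored
# eagle_worker.verify() does after calling spec_info.verify().
kept_logits = next_token_logits[res.accepeted_indices_cpu]
assert kept_logits.shape[0] == len(res.accepeted_indices_cpu)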
python/sglang/srt/speculative/eagle_worker.py  View file @ d3d4d767
 import logging
 import os
 import time
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from huggingface_hub import snapshot_download
...
@@ -22,11 +22,13 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
 from sglang.srt.speculative.eagle_utils import (
     EagleDraftInput,
     EagleVerifyInput,
+    EagleVerifyOutput,
     assign_draft_cache_locs,
     fast_topk,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import get_available_gpu_memory

 logger = logging.getLogger(__name__)
...
@@ -42,12 +44,16 @@ class EAGLEWorker(TpModelWorker):
         nccl_port: int,
         target_worker: TpModelWorker,
     ):
+        # Override context length with target model's context length
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
+
         # Do not capture cuda graph in `super().__init__()`
         # We will capture it later
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True

-        # Load hot token ids
+        # Lossy optimization by using hot tokens
         if server_args.speculative_token_map is not None:
             self.hot_token_id = load_token_map(server_args.speculative_token_map)
             server_args.json_model_override_args = (
...
@@ -56,6 +62,12 @@ class EAGLEWorker(TpModelWorker):
         else:
             self.hot_token_id = None

+        # We share the allocator with a target worker. Draft/target worker
+        # owns its own KV cache.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
         # Init target worker
         super().__init__(
             gpu_id=gpu_id,
...
@@ -64,9 +76,10 @@ class EAGLEWorker(TpModelWorker):
             nccl_port=nccl_port,
             dp_rank=dp_rank,
             is_draft_worker=True,
+            req_to_token_pool=self.req_to_token_pool,
+            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
         )
         self.target_worker = target_worker
-        self.finish_extend_len = []

         # Parse arguments
         self.topk = server_args.speculative_eagle_topk
...
@@ -75,6 +88,9 @@ class EAGLEWorker(TpModelWorker):
             server_args.speculative_algorithm
         )
         self.server_args = server_args
+        self.use_nan_detection = self.server_args.enable_nan_detection
+        self.device = self.model_runner.device
+        self.gpu_id = self.model_runner.gpu_id

         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
...
@@ -82,8 +98,10 @@ class EAGLEWorker(TpModelWorker):
             head = head.clone()
             self.hot_token_id = self.hot_token_id.to(head.device)
             head.data = head.data[self.hot_token_id]
-        self.model_runner.model.set_embed_and_head(embed, head)
-        self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
+        self.draft_model_runner.model.set_embed_and_head(embed, head)
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
+        )

         # Create multi-step attn backends and cuda graph runners
         if server_args.attention_backend == "flashinfer":
...
@@ -111,7 +129,7 @@ class EAGLEWorker(TpModelWorker):
                 f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
             )

-        self.model_runner.draft_attn_backend = self.draft_attn_backend
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
         self.init_cuda_graphs()

     def init_cuda_graphs(self):
...
@@ -122,36 +140,71 @@ class EAGLEWorker(TpModelWorker):
             return

         tic = time.time()
-        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+        logger.info(
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
+        logger.info(
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )

+    @property
+    def draft_model_runner(self):
+        return self.model_runner
+
-    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
+    def forward_batch_speculative_generation(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
+        """Run speculative decoding forward.
+
+        NOTE: Many states of batch is modified as you go through. It is not guaranteed
+        the final output batch doesn't have the same state as the input.
+
+        Args:
+            batch: The batch to run forward. The state of the batch is modified as it runs.
+        Returns:
+            A tuple of the final logit output of the target model, next tokens accepeted,
+            the batch id (used for overlap schedule), and number of accepeted tokens.
+        """
+        assert not batch.spec_algorithm.is_none()
+
         if batch.forward_mode.is_decode():
-            spec_info: EagleVerifyInput = self.draft(batch)
-            (
-                next_draft_input,
-                logits_output,
-                verified_id,
-                self.finish_extend_len,
-                accept_length_cpu,
-                model_worker_batch,
-            ) = self.verify(batch, spec_info)
-            batch.spec_info = next_draft_input
-            # if it is None, means all requsets are finished
+            # Draft
+            spec_info, to_free_cache_loc = self.draft(batch)
+
+            # Verify
+            logits_output, verify_output, model_worker_batch = self.verify(
+                batch, spec_info
+            )
+            # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
+            self.token_to_kv_pool_allocator.free(to_free_cache_loc)
+
+            # if it is None, means all requests are finished
             if batch.spec_info.verified_id is not None:
                 self.forward_draft_extend_after_decode(batch)
+
             return (
                 logits_output,
-                verified_id,
-                model_worker_batch,
-                sum(accept_length_cpu),
+                verify_output.verified_id,
+                model_worker_batch.bid,
+                sum(verify_output.accept_length_per_req_cpu),
             )
         else:
+            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
+            self.forward_draft_extend(
+                batch, logits_output.hidden_states, next_token_ids
+            )
+            return logits_output, next_token_ids, bid, 0
+
+    def forward_target_extend(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
+        """Run the target extend.
+
+        Args:
+            batch: The batch to run. States could be modified.
+        Returns:
+            logits_output: The output of logits. It will contain the full hidden states.
+            next_token_ids: Next token ids generated.
+            bid: The model batch ID. Used for overlap schedule.
+        """
         # Forward with the target model and get hidden states.
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
...
@@ -159,18 +212,9 @@ class EAGLEWorker(TpModelWorker):
         logits_output, next_token_ids = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
-
-        # Forward with the draft model.
-        batch.spec_info = EagleDraftInput(
-            hidden_states=logits_output.hidden_states,
-            verified_id=next_token_ids,
-        )
-        self.forward_draft_extend(batch)
-        return logits_output, next_token_ids, model_worker_batch, 0
+        return logits_output, next_token_ids, model_worker_batch.bid

     def draft(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
-
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
...
@@ -188,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
         )
-
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
...
@@ -196,11 +239,12 @@ class EAGLEWorker(TpModelWorker):
         # Get forward batch
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.draft_model_runner)
         can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
             forward_batch
         )
         if can_cuda_graph:
             score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                 forward_batch
...
@@ -208,7 +252,9 @@ class EAGLEWorker(TpModelWorker):
         else:
             # Initialize attention backend
             self.draft_attn_backend.init_forward_metadata(forward_batch)
+            forward_batch = ForwardBatch.init_new(model_worker_batch, self.draft_model_runner)
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)
...
@@ -225,10 +271,7 @@ class EAGLEWorker(TpModelWorker):
             batch.sampling_info.is_all_greedy,
         )

-        # Free cache locations
-        batch.token_to_kv_pool.free(out_cache_loc)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
-        return ret
+        return ret, out_cache_loc

     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
...
@@ -278,6 +321,7 @@ class EAGLEWorker(TpModelWorker):
             logits_output = self.model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
+            self._detect_nan_if_needed(logits_output)
             probs = torch.softmax(logits_output.next_token_logits, dim=-1)
             topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
             if self.hot_token_id is not None:
...
@@ -294,71 +338,88 @@ class EAGLEWorker(TpModelWorker):
         logits_output, _ = self.target_worker.forward_batch_generation(
             model_worker_batch, skip_sample=True
         )
+        self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
-        res = spec_info.verify(batch, logits_output)
+        res: EagleVerifyOutput = spec_info.verify(
+            batch, logits_output, self.token_to_kv_pool_allocator
+        )
+
+        # Post process based on verified outputs.
+        # Pick indices that we care (accepeted)
+        logits_output.next_token_logits = logits_output.next_token_logits[
+            res.accepeted_indices_cpu
+        ]
+        logits_output.hidden_states = logits_output.hidden_states[
+            res.accepeted_indices_cpu
+        ]
+
+        # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
-        return res + (model_worker_batch,)
+        batch.spec_info = res.draft_input
+
+        return logits_output, res, model_worker_batch

-    def forward_draft_extend(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
+    def forward_draft_extend(
+        self,
+        batch: ScheduleBatch,
+        hidden_states: torch.Tensor,
+        next_token_ids: List[int],
+    ):
+        """Run draft model extend. This API modifies the states of the batch.
+
+        Args:
+            batch: The batch to run.
+            hidden_states: Hidden states from the target model forward
+            next_token_ids: Next token ids generated from the target forward.
+        """
+        batch.spec_info = EagleDraftInput(
+            hidden_states=hidden_states,
+            verified_id=next_token_ids,
+        )
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
-
-    def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
-        batch.token_to_kv_pool = runner.token_to_kv_pool
-        batch.req_to_token_pool = runner.req_to_token_pool
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         seq_lens_backup = batch.seq_lens
-        self._set_mem_pool(batch, self.model_runner)
+        req_pool_indices_backup = batch.req_pool_indices

         batch.forward_mode = ForwardMode.DRAFT_EXTEND
         batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        # We don't need logprob for this extend.
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)

         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
+        batch.req_pool_indices = req_pool_indices_backup

     def capture_for_decode(
-        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
     ):
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
-        spec_info = forward_batch.spec_info
-        spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
-        spec_info.hidden_states = logits_output.hidden_states
+        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
+        draft_input.hidden_states = logits_output.hidden_states

-    # Don't support prefix share now.
-    def finish_request(self, reqs: Union[Req, List[Req]]):
-        if not isinstance(reqs, List):
-            reqs = [reqs]
-        for req in reqs:
-            if req.rid not in self.finish_extend_len:
-                continue
-            req_len = (
-                len(req.origin_input_ids)
-                + len(req.output_ids)
-                - self.finish_extend_len[req.rid]
-                - 1
-            )
-            kv_indices = self.model_runner.req_to_token_pool.req_to_token[
-                req.req_pool_idx
-            ][:req_len]
-            self.model_runner.token_to_kv_pool.free(kv_indices)
-            self.model_runner.req_to_token_pool.free(req.req_pool_idx)
+    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
+        if self.use_nan_detection:
+            logits = logits_output.next_token_logits
+            if torch.any(torch.isnan(logits)):
+                logger.warning("Detected errors during sampling! NaN in the logits.")
+                raise ValueError("Detected errors during sampling! NaN in the logits.")


 def load_token_map(token_map_path: str) -> List[int]:
...
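Taken together, the refactored decode path in `forward_batch_speculative_generation` is a fixed pipeline: draft with the small model, verify with the target model, free the draft-step cache slots, then run a draft-extend pass for requests that are still alive. The sketch below only captures that ordering; all helper names are placeholders passed in as callables, and only the call order and returned tuple shape mirror the diff.

def speculative_decode_step(batch, draft, verify, free_slots, draft_extend_after_decode):
    """Placeholder pipeline mirroring the order of calls in the new decode branch."""
    # 1. Draft: propose candidate tokens and remember the cache slots they used.
    spec_info, to_free_cache_loc = draft(batch)

    # 2. Verify: the target model scores the candidates; the draft-step slots are
    #    freed afterwards so the free() hides behind other kernel launches.
    logits_output, verify_output, model_worker_batch = verify(batch, spec_info)
    free_slots(to_free_cache_loc)

    # 3. Draft-extend: refresh the draft model's state for unfinished requests.
    if batch.spec_info.verified_id is not None:
        draft_extend_after_decode(batch)

    return (
        logits_output,
        verify_output.verified_id,
        model_worker_batch.bid,
        sum(verify_output.accept_length_per_req_cpu),
    )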