sglang / Commits / d3d4d767

Unverified commit d3d4d767, authored Mar 05, 2025 by Ying Sheng, committed by GitHub on Mar 05, 2025

[Eagle] Refactor eagle speculative decoding (#3986)

Co-authored-by: Ke Bao <ISPObaoke@163.com>

parent 5be8f1ed
Showing 20 changed files with 529 additions and 349 deletions (+529, -349)
python/sglang/bench_one_batch.py                                  +2    -2
python/sglang/srt/layers/attention/flashinfer_backend.py         +81   -58
python/sglang/srt/layers/attention/triton_backend.py             +1    -0
python/sglang/srt/managers/cache_controller.py                   +2    -2
python/sglang/srt/managers/schedule_batch.py                     +26   -24
python/sglang/srt/managers/schedule_policy.py                    +19   -14
python/sglang/srt/managers/scheduler.py                          +31   -26
python/sglang/srt/managers/tokenizer_manager.py                  +0    -17
python/sglang/srt/managers/tp_worker.py                          +6    -1
python/sglang/srt/managers/tp_worker_overlap_thread.py           +1    -1
python/sglang/srt/mem_cache/chunk_cache.py                       +9    -5
python/sglang/srt/mem_cache/hiradix_cache.py                     +4    -4
python/sglang/srt/mem_cache/memory_pool.py                       +41   -18
python/sglang/srt/mem_cache/radix_cache.py                       +11   -6
python/sglang/srt/model_executor/model_runner.py                 +39   -8
python/sglang/srt/server_args.py                                 +8    -3
python/sglang/srt/speculative/build_eagle_tree.py                +2    -8
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py   +0    -1
python/sglang/srt/speculative/eagle_utils.py                     +92   -58
python/sglang/srt/speculative/eagle_worker.py                    +154  -93
python/sglang/bench_one_batch.py
...
@@ -230,7 +230,7 @@ def extend(reqs, model_runner):
    batch = ScheduleBatch.init_new(
        reqs=reqs,
        req_to_token_pool=model_runner.req_to_token_pool,
-       token_to_kv_pool=model_runner.token_to_kv_pool,
+       token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
        tree_cache=None,
        model_config=model_runner.model_config,
        enable_overlap=False,
...
@@ -326,7 +326,7 @@ def latency_test_run_once(
    # Clear the pools.
    model_runner.req_to_token_pool.clear()
-   model_runner.token_to_kv_pool.clear()
+   model_runner.token_to_kv_pool_allocator.clear()

    measurement_results = {
        "run_name": run_name,
...
python/sglang/srt/layers/attention/flashinfer_backend.py
...
@@ -20,14 +20,15 @@ import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
-   from sglang.srt.speculative.spec_info import SpecInfo

if is_flashinfer_available():
    from flashinfer import (
...
@@ -36,6 +37,7 @@ if is_flashinfer_available():
        BatchPrefillWithRaggedKVCacheWrapper,
    )
    from flashinfer.cascade import merge_state
    from flashinfer.decode import PosEncodingMode


class WrapperDispatch(Enum):
...
@@ -113,6 +115,7 @@ class FlashInferAttnBackend(AttentionBackend):
                device=model_runner.device,
            )
            self.workspace_buffer = global_workspace_buffer

        max_bs = model_runner.req_to_token_pool.size
        if kv_indptr_buf is None:
            self.kv_indptr = [
...
@@ -133,10 +136,13 @@ class FlashInferAttnBackend(AttentionBackend):
            assert self.num_wrappers == 1
            self.kv_last_page_len = kv_last_page_len_buf

-       self.qo_indptr = [
-           torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
-           for _ in range(self.num_wrappers)
-       ]
+       if not self.skip_prefill:
+           self.qo_indptr = [
+               torch.zeros(
+                   (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+               )
+               for _ in range(self.num_wrappers)
+           ]
            self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
                self.workspace_buffer, "NHD"
...
@@ -276,7 +282,7 @@ class FlashInferAttnBackend(AttentionBackend):
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
-       spec_info: Optional[SpecInfo],
+       spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        if forward_mode.is_decode_or_idle():
            decode_wrappers = []
...
@@ -346,7 +352,7 @@ class FlashInferAttnBackend(AttentionBackend):
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
-       spec_info: Optional[SpecInfo],
+       spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        if forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
...
@@ -526,7 +532,7 @@ class FlashInferIndicesUpdaterDecode:
        seq_lens_sum: int,
        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
        encoder_lens: Optional[torch.Tensor],
-       spec_info: Optional[SpecInfo],
+       spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()
...
@@ -538,7 +544,7 @@ class FlashInferIndicesUpdaterDecode:
        seq_lens_sum: int,
        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
        encoder_lens: Optional[torch.Tensor],
-       spec_info: Optional[SpecInfo],
+       spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        decode_wrappers = decode_wrappers or self.decode_wrappers
        self.call_begin_forward(
...
@@ -558,7 +564,7 @@ class FlashInferIndicesUpdaterDecode:
        seq_lens_sum: int,
        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
        encoder_lens: Optional[torch.Tensor],
-       spec_info: Optional[SpecInfo],
+       spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
...
@@ -592,7 +598,7 @@ class FlashInferIndicesUpdaterDecode:
        seq_lens_sum: int,
        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
        encoder_lens: Optional[torch.Tensor],
-       spec_info: Optional[SpecInfo],
+       spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
...
@@ -623,7 +629,7 @@ class FlashInferIndicesUpdaterDecode:
        paged_kernel_lens_sum: int,
        kv_indptr: torch.Tensor,
        kv_start_idx: torch.Tensor,
-       spec_info: Optional[SpecInfo],
+       spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        if spec_info is None:
            bs = len(req_pool_indices)
...
@@ -642,9 +648,9 @@ class FlashInferIndicesUpdaterDecode:
                self.req_to_token.shape[1],
            )
        else:
            assert isinstance(spec_info, EagleDraftInput)
            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
            bs = kv_indptr.shape[0] - 1

        wrapper.begin_forward(
            kv_indptr,
            kv_indices,
...
@@ -699,7 +705,7 @@ class FlashInferIndicesUpdaterPrefill:
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
        encoder_lens: Optional[torch.Tensor],
-       spec_info: Optional[SpecInfo],
+       spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()
...
@@ -713,7 +719,7 @@ class FlashInferIndicesUpdaterPrefill:
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
        encoder_lens: Optional[torch.Tensor],
-       spec_info: Optional[SpecInfo],
+       spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        if use_ragged:
            paged_kernel_lens = prefix_lens
...
@@ -746,7 +752,7 @@ class FlashInferIndicesUpdaterPrefill:
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
        encoder_lens: Optional[torch.Tensor],
-       spec_info: Optional[SpecInfo],
+       spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
...
@@ -787,7 +793,7 @@ class FlashInferIndicesUpdaterPrefill:
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
        encoder_lens: Optional[torch.Tensor],
-       spec_info: Optional[SpecInfo],
+       spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
...
@@ -829,10 +835,11 @@ class FlashInferIndicesUpdaterPrefill:
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        use_ragged: bool,
-       spec_info: Optional[SpecInfo],
+       spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
-       bs = len(req_pool_indices)
+       bs = len(seq_lens)
        if spec_info is None:
+           assert len(seq_lens) == len(req_pool_indices)
            # Normal extend
            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
            kv_indptr = kv_indptr[: bs + 1]
...
@@ -855,10 +862,14 @@ class FlashInferIndicesUpdaterPrefill:
            qo_indptr = qo_indptr[: bs + 1]
            custom_mask = None
        else:
+           assert isinstance(spec_info, EagleDraftInput) or isinstance(
+               spec_info, EagleVerifyInput
+           )
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    paged_kernel_lens,
+                   paged_kernel_lens_sum,
                    self.req_to_token,
                )
            )
...
@@ -890,6 +901,11 @@ class FlashInferIndicesUpdaterPrefill:
            )


+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global global_override_indptr_cpu


class FlashInferMultiStepDraftBackend:
    """
    Wrap multiple flashinfer attention backends as one for multiple consecutive
...
@@ -907,6 +923,7 @@ class FlashInferMultiStepDraftBackend:
        self.topk = topk
        self.speculative_num_steps = speculative_num_steps
        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+       max_bs = model_runner.req_to_token_pool.size * self.topk
        self.kv_indptr = torch.zeros(
            (
...
@@ -929,7 +946,9 @@ class FlashInferMultiStepDraftBackend:
                    kv_last_page_len_buf=self.kv_last_page_len,
                )
            )

        self.max_context_len = self.attn_backends[0].max_context_len

+       # Cached variables for generate_draft_decode_kv_indices
+       self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
...
@@ -959,13 +978,23 @@ class FlashInferMultiStepDraftBackend:
            triton.next_power_of_2(bs),
        )

+       assert forward_batch.spec_info is not None
+       assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+       # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
+       indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
+       global global_override_indptr_cpu
+
        for i in range(self.speculative_num_steps - 1):
            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                : seq_lens_sum * self.topk + bs * (i + 1)
            ]
+           global_override_indptr_cpu = indptr_cpu_whole[i]
            call_fn(i, forward_batch)
+       global_override_indptr_cpu = None

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        kv_indices = torch.zeros(
            (
...
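The hunk above copies kv_indptr to the host once and then exposes one slice per draft step through the module-level global_override_indptr_cpu, so each step's plan call can skip its own device-to-host copy. Below is a minimal, self-contained sketch of that set-override / call / reset pattern; all names in it (HOST_INDPTR_OVERRIDE, plan_step, run_draft_steps) are illustrative and are not part of flashinfer or sglang.

# Sketch of the "set a module-level override, call the planner, then reset" pattern.
from typing import List, Optional

HOST_INDPTR_OVERRIDE: Optional[List[int]] = None  # module-level override slot


def plan_step(device_indptr: List[int]) -> List[int]:
    # A real planner would copy `device_indptr` from device to host here; the
    # override lets the caller hand in a host copy it already made for all steps.
    if HOST_INDPTR_OVERRIDE is not None:
        return HOST_INDPTR_OVERRIDE
    return list(device_indptr)


def run_draft_steps(indptr_per_step: List[List[int]]) -> None:
    global HOST_INDPTR_OVERRIDE
    host_copies = [list(step) for step in indptr_per_step]  # one host copy up front
    for i, step_indptr in enumerate(indptr_per_step):
        HOST_INDPTR_OVERRIDE = host_copies[i]  # expose this step's host copy
        plan_step(step_indptr)
    HOST_INDPTR_OVERRIDE = None  # reset so later callers never see a stale override


if __name__ == "__main__":
    run_draft_steps([[0, 4, 8], [0, 5, 10]])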
@@ -977,6 +1006,8 @@ class FlashInferMultiStepDraftBackend:
        )

        def call_fn(i, forward_batch):
+           assert forward_batch.spec_info is not None
+           assert isinstance(forward_batch.spec_info, EagleDraftInput)
            forward_batch.spec_info.kv_indptr = (
                forward_batch.spec_info.kv_indptr.clone()
            )
...
@@ -993,6 +1024,7 @@ class FlashInferMultiStepDraftBackend:
            dtype=torch.int32,
            device="cuda",
        )

        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
...
@@ -1031,43 +1063,6 @@ class FlashInferMultiStepDraftBackend:
        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    kv_indices_ptr,
-    req_to_token_ptr_stride: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = offset < kv_end - kv_start
-        data = tl.load(
-            req_to_token_ptr
-            + req_pool_index * req_to_token_ptr_stride
-            + kv_start
-            + offset,
-            mask=mask,
-        )
-        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)


def should_use_tensor_core(
    kv_cache_dtype: torch.dtype,
    num_attention_heads: int,
...
@@ -1089,6 +1084,21 @@ def should_use_tensor_core(
    if env_override is not None:
        return env_override.lower() == "true"

+   # Try to use _grouped_size_compiled_for_decode_kernels if available
+   # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+   try:
+       from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+       if not _grouped_size_compiled_for_decode_kernels(
+           num_attention_heads,
+           num_kv_heads,
+       ):
+           return True
+       else:
+           return False
+   except (ImportError, AttributeError):
+       pass
+
    # Calculate GQA group size
    gqa_group_size = num_attention_heads // num_kv_heads
...
@@ -1118,12 +1128,18 @@ def fast_decode_plan(
    sm_scale: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    **kwargs,
+   non_blocking: bool = True,
) -> None:
-   """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
+   """
+   A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
+   Modifications:
+   - Remove unnecessary device-to-device copy for the cuda graph buffers.
+   - Remove unnecessary host-to-device copy for the metadata buffers.
+   """
    batch_size = len(last_page_len)
    if logits_soft_cap is None:
        logits_soft_cap = 0.0

    if self.is_cuda_graph_enabled:
        if batch_size != self._fixed_batch_size:
            raise ValueError(
...
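fast_decode_plan keeps the signature of the wrapper's plan call while dropping the copies listed in its docstring. One way such a replacement could be wired in, sketched here with stand-in classes rather than the real flashinfer wrapper, is to bind the free function to a wrapper instance with functools.partial; DummyDecodeWrapper and fast_plan below are illustrative only.

# Sketch: swapping an instance's planning method for a faster free function.
from functools import partial


class DummyDecodeWrapper:
    def plan(self, indptr, indices, last_page_len):
        # Imagine redundant buffer copies of indptr/indices/last_page_len here.
        self.planned = ("slow", len(indices))


def fast_plan(wrapper, indptr, indices, last_page_len):
    # Same signature as the method it replaces, minus the redundant copies.
    wrapper.planned = ("fast", len(indices))


wrapper = DummyDecodeWrapper()
wrapper.plan = partial(fast_plan, wrapper)  # later wrapper.plan(...) calls hit fast_plan
wrapper.plan([0, 3], [10, 11, 12], [1])
assert wrapper.planned == ("fast", 3)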
@@ -1136,13 +1152,19 @@ def fast_decode_plan(
                raise ValueError(
                    "The size of indices should be less than or equal to the allocated buffer"
                )
+           # Skip these copies
+           # self._paged_kv_indptr_buf.copy_(indptr)
+           # self._paged_kv_indices_buf[: len(indices)] = indices
+           # self._paged_kv_last_page_len_buf.copy_(last_page_len)
        else:
            self._paged_kv_indptr_buf = indptr
            self._paged_kv_indices_buf = indices
            self._paged_kv_last_page_len_buf = last_page_len

    # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
+   if not q_data_type:
+       q_data_type = data_type
+
    if not hasattr(self, "empty_q_data"):
        self.empty_q_data = torch.empty(
            0,
...
@@ -1159,6 +1181,7 @@ def fast_decode_plan(
            ),
        )
+       self.last_page_len = torch.ones(32768, dtype=torch.int32)

    empty_q_data = self.empty_q_data
    empty_kv_cache = self.empty_kv_cache
    stream = torch.cuda.current_stream()
...
python/sglang/srt/layers/attention/triton_backend.py
...
@@ -156,6 +156,7 @@ class TritonAttnBackend(AttentionBackend):
                spec_info.generate_attn_arg_prefill(
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
+                   None,
                    self.req_to_token,
                )
            )
...
python/sglang/srt/managers/cache_controller.py
...
@@ -22,7 +22,7 @@ from typing import List, Optional
import torch

-from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost

logger = logging.getLogger(__name__)
...
@@ -128,7 +128,7 @@ class HiCacheController:
    def __init__(
        self,
        mem_pool_device: MHATokenToKVPool,
-       mem_pool_host: MLATokenToKVPoolHost,
+       mem_pool_host: MHATokenToKVPoolHost,
        write_policy: str = "write_through_selective",
    ):
...
python/sglang/srt/managers/schedule_batch.py
...
@@ -44,18 +44,16 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
-from sglang.srt.server_args import ServerArgs

if TYPE_CHECKING:
+   from sglang.srt.server_args import ServerArgs
+   from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access
...
@@ -523,7 +521,7 @@ class ScheduleBatch:
    # Request, memory pool, and cache
    reqs: List[Req]
    req_to_token_pool: ReqToTokenPool = None
-   token_to_kv_pool: BaseTokenToKVPool = None
+   token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
    tree_cache: BasePrefixCache = None

    # Batch configs
...
@@ -596,7 +594,7 @@ class ScheduleBatch:
        cls,
        reqs: List[Req],
        req_to_token_pool: ReqToTokenPool,
-       token_to_kv_pool: ReqToTokenPool,
+       token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        tree_cache: BasePrefixCache,
        model_config: ModelConfig,
        enable_overlap: bool,
...
@@ -606,7 +604,7 @@ class ScheduleBatch:
        return cls(
            reqs=reqs,
            req_to_token_pool=req_to_token_pool,
-           token_to_kv_pool=token_to_kv_pool,
+           token_to_kv_pool_allocator=token_to_kv_pool_allocator,
            tree_cache=tree_cache,
            model_config=model_config,
            enable_overlap=enable_overlap,
...
@@ -637,19 +635,19 @@ class ScheduleBatch:
        return req_pool_indices

    def alloc_token_slots(self, num_tokens: int):
-       out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+       out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

        if out_cache_loc is None:
            if self.tree_cache is not None:
-               self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
-               out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+               self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
+               out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

            if out_cache_loc is None:
                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                logger.error(
                    f"{phase_str} out of memory. Try to lower your batch size.\n"
                    f"Try to allocate {num_tokens} tokens.\n"
-                   f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+                   f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                )
                if self.tree_cache is not None:
                    self.tree_cache.pretty_print()
...
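alloc_token_slots above follows an allocate, evict-from-cache, retry-once shape: ask the allocator, and if it is out of slots, evict reclaimable entries from the prefix cache back into the allocator and try one more time. A toy sketch of that control flow, using a made-up FreeListAllocator in place of TokenToKVPoolAllocator:

# Sketch of the allocate -> evict from cache -> retry flow in alloc_token_slots.
from typing import List, Optional


class FreeListAllocator:
    def __init__(self, size: int) -> None:
        self.free_slots = list(range(size))

    def alloc(self, n: int) -> Optional[List[int]]:
        if len(self.free_slots) < n:
            return None  # mirrors the allocator returning None when full
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def free(self, slots: List[int]) -> None:
        self.free_slots.extend(slots)


def alloc_with_eviction(allocator: FreeListAllocator, cached: List[int], n: int) -> List[int]:
    out = allocator.alloc(n)
    if out is None:
        # Evict cached-but-unreferenced slots back to the allocator, then retry once.
        allocator.free(cached[:n])
        del cached[:n]
        out = allocator.alloc(n)
    if out is None:
        raise RuntimeError(f"out of memory: tried to allocate {n} tokens")
    return out


allocator = FreeListAllocator(size=4)
cache = allocator.alloc(3)                        # 3 slots held by the prefix cache
print(alloc_with_eviction(allocator, cache, 2))   # triggers eviction, then succeeds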
@@ -917,12 +915,12 @@ class ScheduleBatch:
    def check_decode_mem(self, buf_multiplier=1):
        bs = len(self.reqs) * buf_multiplier
-       if self.token_to_kv_pool.available_size() >= bs:
+       if self.token_to_kv_pool_allocator.available_size() >= bs:
            return True

-       self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+       self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)

-       if self.token_to_kv_pool.available_size() >= bs:
+       if self.token_to_kv_pool_allocator.available_size() >= bs:
            return True

        return False
...
@@ -945,6 +943,10 @@ class ScheduleBatch:
            reverse=True,
        )

+       retracted_reqs = []
+       seq_lens_cpu = self.seq_lens.cpu().numpy()
+       first_iter = True
+
        def get_required_tokens(num_reqs: int):
            headroom_for_spec_decode = 0
            if server_args.speculative_algorithm:
...
@@ -958,18 +960,15 @@ class ScheduleBatch:
                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
            )

-       retracted_reqs = []
-       seq_lens_cpu = self.seq_lens.cpu().numpy()
-       first_iter = True
        while (
-           self.token_to_kv_pool.available_size()
+           self.token_to_kv_pool_allocator.available_size()
            < get_required_tokens(len(sorted_indices))
            or first_iter
        ):
            if len(sorted_indices) == 1:
                # Corner case: only one request left
                assert (
-                   self.token_to_kv_pool.available_size() > 0
+                   self.token_to_kv_pool_allocator.available_size() > 0
                ), "No space left for only one request"
                break
...
@@ -983,7 +982,7 @@ class ScheduleBatch:
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, : seq_lens_cpu[idx]
                ]
-               self.token_to_kv_pool.free(token_indices)
+               self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)
                del self.tree_cache.entries[req.rid]
            else:
...
@@ -992,7 +991,7 @@ class ScheduleBatch:
                token_indices = self.req_to_token_pool.req_to_token[
                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                ]
-               self.token_to_kv_pool.free(token_indices)
+               self.token_to_kv_pool_allocator.free(token_indices)
                self.req_to_token_pool.free(req.req_pool_idx)

                # release the last node
...
@@ -1001,10 +1000,13 @@ class ScheduleBatch:
                # NOTE(lsyin): we should use the newly evictable memory instantly.
                residual_size = (
                    len(sorted_indices) * global_config.retract_decode_steps
-                   - self.token_to_kv_pool.available_size()
+                   - self.token_to_kv_pool_allocator.available_size()
                )
                residual_size = max(0, residual_size)
-               self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
+               self.tree_cache.evict(
+                   residual_size, self.token_to_kv_pool_allocator.free
+               )

            req.reset_for_retract()

        self.filter_batch(keep_indices=sorted_indices)
...
@@ -1183,7 +1185,7 @@ class ScheduleBatch:
        if self.spec_info:
            self.spec_info.merge_batch(other.spec_info)

-   def get_model_worker_batch(self):
+   def get_model_worker_batch(self) -> ModelWorkerBatch:
        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
...
@@ -1273,7 +1275,7 @@ class ModelWorkerBatch:
    req_pool_indices: torch.Tensor
    # The sequence length
    seq_lens: torch.Tensor
-   # The indices of output tokens in the token_to_kv_pool
+   # The indices of output tokens in the token_to_kv_pool_allocator
    out_cache_loc: torch.Tensor

    # The sum of all sequence lengths
...
python/sglang/srt/managers/schedule_policy.py
...
@@ -22,9 +22,13 @@ from typing import Dict, List, Optional, Set, Union
import torch

-from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.schedule_batch import (
+    Req,
+    ScheduleBatch,
+    global_server_args_dict,
+)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
...
@@ -75,7 +79,7 @@ class SchedulePolicy:
        # It is used to find the matching prefix for in-batch prefix caching.
        self.waiting_queue_radix_tree = RadixCache(
-           req_to_token_pool=None, token_to_kv_pool=None, disable=False
+           req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
        )

    def calc_priority(self, waiting_queue: List[Req]) -> bool:
...
@@ -251,7 +255,7 @@ class PrefillAdder:
    def __init__(
        self,
        tree_cache: BasePrefixCache,
-       token_to_kv_pool: BaseTokenToKVPool,
+       token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        running_batch: ScheduleBatch,
        new_token_ratio: float,
        rem_input_tokens: int,
...
@@ -259,7 +263,7 @@ class PrefillAdder:
        mixed_with_decode_tokens: int = 0,
    ):
        self.tree_cache = tree_cache
-       self.token_to_kv_pool = token_to_kv_pool
+       self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.running_batch = running_batch
        self.new_token_ratio = new_token_ratio
        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
...
@@ -291,7 +295,7 @@ class PrefillAdder:
    @property
    def rem_total_tokens(self):
        return (
-           self.token_to_kv_pool.available_size()
+           self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
            - self.rem_total_token_offset
        )
...
@@ -299,7 +303,7 @@ class PrefillAdder:
    @property
    def cur_rem_tokens(self):
        return (
-           self.token_to_kv_pool.available_size()
+           self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
            - self.cur_rem_token_offset
        )
...
@@ -332,7 +336,6 @@ class PrefillAdder:
            req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
            req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
        self.can_run_list.append(req)
        self._prefill_one_req(
            0,
            req.extend_input_len,
...
@@ -400,8 +403,8 @@ class PrefillAdder:
            tokens_freed += tokens_occupied

        if (
-           self.rem_chunk_tokens is None or req.extend_input_len <= self.rem_chunk_tokens
+           self.rem_chunk_tokens is None  # chunked prefill is disabled
+           or req.extend_input_len <= self.rem_chunk_tokens  # it is the last chunk
        ):
            # Non-chunked prefill
            self.can_run_list.append(req)
...
@@ -411,10 +414,11 @@ class PrefillAdder:
                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
            )
        else:
+           if self.rem_chunk_tokens == 0:
+               return AddReqResult.OTHER
+
            # Chunked prefill
            trunc_len = self.rem_chunk_tokens
-           if trunc_len == 0:
-               return AddReqResult.OTHER

            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[:trunc_len]
...
@@ -457,10 +461,11 @@ class PrefillAdder:
                ),
            )
        else:
+           if self.rem_chunk_tokens == 0:
+               return AddReqResult.OTHER
+
            # Chunked prefill
            trunc_len = self.rem_chunk_tokens
-           if trunc_len == 0:
-               return AddReqResult.OTHER

            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
...
python/sglang/srt/managers/scheduler.py
...
@@ -164,7 +164,7 @@ class Scheduler:
                self.server_args.speculative_num_draft_tokens
                + (
                    self.server_args.speculative_eagle_topk
-                   * self.server_args.speculative_num_steps
+                   * self.server_args.speculative_num_draft_tokens
                )
            )
            if not self.spec_algorithm.is_none()
...
@@ -309,7 +309,9 @@ class Scheduler:
        )

        # Init memory pool and cache
-       self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
+       self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+           self.tp_worker.get_memory_pool()
+       )

        if (
            server_args.chunked_prefill_size is not None
...
@@ -317,18 +319,18 @@ class Scheduler:
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
-               token_to_kv_pool=self.token_to_kv_pool,
+               token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )
        else:
            if self.enable_hierarchical_cache:
                self.tree_cache = HiRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
-                   token_to_kv_pool=self.token_to_kv_pool,
+                   token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                )
            else:
                self.tree_cache = RadixCache(
                    req_to_token_pool=self.req_to_token_pool,
-                   token_to_kv_pool=self.token_to_kv_pool,
+                   token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    disable=server_args.disable_radix_cache,
                )
...
@@ -458,7 +460,6 @@ class Scheduler:
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
-               (SetInternalStateReq, self.set_internal_state),
            ]
        )
...
@@ -809,7 +810,8 @@ class Scheduler:
        running_bs: int,
    ):
        num_used = self.max_total_num_tokens - (
-           self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+           self.token_to_kv_pool_allocator.available_size()
+           + self.tree_cache.evictable_size()
        )
        self._largest_prefill_len = max(
            self._largest_prefill_len, adder.log_input_tokens
...
@@ -844,7 +846,8 @@ class Scheduler:
        self.num_generated_tokens = 0
        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        num_used = self.max_total_num_tokens - (
-           self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+           self.token_to_kv_pool_allocator.available_size()
+           + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
...
@@ -894,7 +897,8 @@ class Scheduler:
    def check_memory(self):
        available_size = (
-           self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+           self.token_to_kv_pool_allocator.available_size()
+           + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        memory_leak = available_size != (
...
@@ -999,7 +1003,7 @@ class Scheduler:
        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
-           self.token_to_kv_pool,
+           self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
...
@@ -1099,7 +1103,7 @@ class Scheduler:
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
-           self.token_to_kv_pool,
+           self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
...
@@ -1143,8 +1147,6 @@ class Scheduler:
        retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
        self.new_token_ratio = new_token_ratio
-       if self.draft_worker:
-           self.draft_worker.finish_request(retracted_reqs)

        logger.info(
            "Decode out of memory happened. "
...
@@ -1184,11 +1186,12 @@ class Scheduler:
            logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                model_worker_batch
            )
+           bid = model_worker_batch.bid
        else:
            (
                logits_output,
                next_token_ids,
                model_worker_batch,
+               bid,
                num_accepted_tokens,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            self.spec_num_total_accepted_tokens += (
...
@@ -1214,7 +1217,7 @@ class Scheduler:
                next_token_ids=next_token_ids,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-               bid=model_worker_batch.bid,
+               bid=bid,
            )
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
...
@@ -1230,6 +1233,7 @@ class Scheduler:
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        if batch.forward_mode.is_decode():
+           assert isinstance(result, GenerationBatchResult)
            self.process_batch_result_decode(batch, result)
            if batch.is_empty():
                self.running_batch = None
...
@@ -1302,7 +1306,7 @@ class Scheduler:
                if self.is_mixed_chunk and self.enable_overlap and req.finished():
                    # Free the one delayed token for the mixed decode batch
                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
-                   self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
+                   self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
                    continue

                if req.is_chunked <= 0:
...
@@ -1420,23 +1424,27 @@ class Scheduler:
        self.num_generated_tokens += len(batch.reqs)

        if self.enable_overlap:
+           assert batch.spec_algorithm.is_none()
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
-       else:
+       elif batch.spec_algorithm.is_none():
+           # spec decoding handles output logprobs inside verify process.
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()

-       self.token_to_kv_pool.free_group_begin()
+       self.token_to_kv_pool_allocator.free_group_begin()

        # Check finish condition
+       # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
+       # We should ignore using next_token_ids for spec decoding cases.
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one delayed token
-               self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
+               self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
                continue

            if batch.spec_algorithm.is_none():
...
@@ -1479,7 +1487,7 @@ class Scheduler:
            batch.next_batch_sampling_info.sampling_info_done.set()

        self.stream_output(batch.reqs, batch.return_logprob)

-       self.token_to_kv_pool.free_group_end()
+       self.token_to_kv_pool_allocator.free_group_end()

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
...
@@ -1718,9 +1726,6 @@ class Scheduler:
                    and not self.model_config.is_multimodal_gen
                )
            ):
-               if self.draft_worker and req.finished():
-                   self.draft_worker.finish_request(req)
-
                rids.append(req.rid)
                finished_reasons.append(
                    req.finished_reason.to_json() if req.finished_reason else None
...
@@ -1860,7 +1865,7 @@ class Scheduler:
            idle_batch = ScheduleBatch.init_new(
                [],
                self.req_to_token_pool,
-               self.token_to_kv_pool,
+               self.token_to_kv_pool_allocator,
                self.tree_cache,
                self.model_config,
                self.enable_overlap,
...
@@ -1916,11 +1921,11 @@ class Scheduler:
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
-       self.token_to_kv_pool.clear()
+       self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
-           self.draft_worker.model_runner.token_to_kv_pool.clear()
+           self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
...
python/sglang/srt/managers/tokenizer_manager.py
...
@@ -82,8 +82,6 @@ from sglang.srt.managers.io_struct import (
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    SessionParams,
-   SetInternalStateReq,
-   SetInternalStateReqOutput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
...
@@ -257,9 +255,6 @@ class TokenizerManager:
        self.get_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
-       self.set_internal_state_communicator = _Communicator(
-           self.send_to_scheduler, server_args.dp_size
-       )

        self._result_dispatcher = TypeBasedDispatcher(
            [
...
@@ -309,10 +304,6 @@ class TokenizerManager:
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
-               (
-                   SetInternalStateReqOutput,
-                   self.set_internal_state_communicator.handle_recv,
-               ),
                (HealthCheckOutput, lambda x: None),
            ]
        )
...
@@ -774,14 +765,6 @@ class TokenizerManager:
        )
        return res[0].internal_state

-   async def set_internal_state(
-       self, obj: SetInternalStateReq
-   ) -> SetInternalStateReqOutput:
-       res: List[SetInternalStateReqOutput] = (
-           await self.set_internal_state_communicator(obj)
-       )
-       return res[0]
-
    def get_log_request_metadata(self):
        max_length = None
        skip_names = None
...
python/sglang/srt/managers/tp_worker.py
...
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
    UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
...
@@ -49,6 +50,8 @@ class TpModelWorker:
        dp_rank: Optional[int],
        nccl_port: int,
        is_draft_worker: bool = False,
+       req_to_token_pool: Optional[ReqToTokenPool] = None,
+       token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.tp_rank = tp_rank
...
@@ -77,6 +80,8 @@ class TpModelWorker:
            nccl_port=nccl_port,
            server_args=server_args,
            is_draft_worker=is_draft_worker,
+           req_to_token_pool=req_to_token_pool,
+           token_to_kv_pool_allocator=token_to_kv_pool_allocator,
        )
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
...
@@ -154,7 +159,7 @@ class TpModelWorker:
    def get_memory_pool(self):
        return (
            self.model_runner.req_to_token_pool,
-           self.model_runner.token_to_kv_pool,
+           self.model_runner.token_to_kv_pool_allocator,
        )

    def forward_batch_generation(
...
python/sglang/srt/managers/tp_worker_overlap_thread.py
...
@@ -100,7 +100,7 @@ class TpModelWorkerClient:
    def get_memory_pool(self):
        return (
            self.worker.model_runner.req_to_token_pool,
-           self.worker.model_runner.token_to_kv_pool,
+           self.worker.model_runner.token_to_kv_pool_allocator,
        )

    def forward_thread_func(self):
...
python/sglang/srt/mem_cache/chunk_cache.py

from __future__ import annotations

"""Cache for chunked prefill, used when RadixCache is disabled."""

from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

import torch

from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req
...
@@ -21,11 +20,13 @@ class ChunkCacheEntry:
class ChunkCache(BasePrefixCache):
    def __init__(
-       self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
+       self,
+       req_to_token_pool: ReqToTokenPool,
+       token_to_kv_pool_allocator: TokenToKVPoolAllocator,
    ):
        self.disable = True
        self.req_to_token_pool = req_to_token_pool
-       self.token_to_kv_pool = token_to_kv_pool
+       self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.entries: Dict[str, ChunkCacheEntry] = {}

        self.reset()
...
@@ -51,7 +52,7 @@ class ChunkCache(BasePrefixCache):
            req.req_pool_idx, :token_id_len
        ]
        self.req_to_token_pool.free(req.req_pool_idx)
-       self.token_to_kv_pool.free(kv_indices)
+       self.token_to_kv_pool_allocator.free(kv_indices)

        if req.rid in self.entries:
            del self.entries[req.rid]
...
@@ -91,3 +92,6 @@ class ChunkCache(BasePrefixCache):
    def protected_size(self):
        return 0
+
+   def pretty_print(self):
+       return ""
python/sglang/srt/mem_cache/hiradix_cache.py
...
@@ -7,8 +7,8 @@ import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
-   BaseTokenToKVPool,
-   MLATokenToKVPoolHost,
+   MHATokenToKVPool,
+   MHATokenToKVPoolHost,
    ReqToTokenPool,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
...
@@ -21,9 +21,9 @@ class HiRadixCache(RadixCache):
    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
-       token_to_kv_pool: BaseTokenToKVPool,
+       token_to_kv_pool: MHATokenToKVPool,
    ):
-       self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool)
+       self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
        self.cache_controller = HiCacheController(
            token_to_kv_pool, self.token_to_kv_pool_host
        )
...
python/sglang/srt/mem_cache/memory_pool.py
...
@@ -20,9 +20,12 @@ Memory pool.
SGLang has two levels of memory pool.
ReqToTokenPool maps a a request to its token locations.
-BaseTokenToKVPool maps a token location to its KV cache data.
+TokenToKVPoolAllocator maps a token location to its KV cache data.
+KVCache actually holds the physical kv cache. Allocation indices are allocated
+by TokenToKVPoolAllocator
"""

+import abc
import logging
import threading
from enum import IntEnum
...
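The updated docstring above describes the split of the old BaseTokenToKVPool into an index allocator (TokenToKVPoolAllocator) and the physical storage (KVCache). A toy sketch of that separation, with simplified shapes and invented class names of my own (not the sglang classes):

# Toy illustration: the allocator hands out token slot indices, the KV cache owns the tensors.
import torch


class ToyKVCache:
    """Physical storage: one K and one V buffer (a real pool has one pair per layer)."""

    def __init__(self, size: int, head_dim: int) -> None:
        self.k_buffer = torch.zeros(size, head_dim)
        self.v_buffer = torch.zeros(size, head_dim)

    def set_kv(self, loc: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> None:
        self.k_buffer[loc] = k
        self.v_buffer[loc] = v


class ToyTokenToKVPoolAllocator:
    """Hands out indices into the KV cache; holds no KV data itself."""

    def __init__(self, size: int) -> None:
        self.free_slots = torch.arange(size)

    def alloc(self, n: int):
        if n > self.free_slots.numel():
            return None
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def free(self, slots: torch.Tensor) -> None:
        self.free_slots = torch.cat([self.free_slots, slots])


kv_cache = ToyKVCache(size=16, head_dim=4)
allocator = ToyTokenToKVPoolAllocator(size=16)
loc = allocator.alloc(3)                       # indices for 3 new tokens
kv_cache.set_kv(loc, torch.ones(3, 4), torch.ones(3, 4))
allocator.free(loc)                            # indices can be recycled later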
@@ -89,7 +92,7 @@ class ReqToTokenPool:
        self.free_slots = list(range(self.size))


-class BaseTokenToKVPool:
+class TokenToKVPoolAllocator:
    """A memory pool that maps a token location to its kv cache data."""

    def __init__(
...
@@ -100,11 +103,6 @@ class TokenToKVPoolAllocator:
    ):
        self.size = size
        self.dtype = dtype
-       if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-           # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-           self.store_dtype = torch.uint8
-       else:
-           self.store_dtype = dtype
        self.device = device

        self.free_slots = None
...
@@ -148,15 +146,22 @@ class TokenToKVPoolAllocator:
        self.is_in_free_group = False
        self.free_group = []


+class KVCache(abc.ABC):
+   @abc.abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

+   @abc.abstractmethod
    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

+   @abc.abstractmethod
    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError()

+   @abc.abstractmethod
    def set_kv_buffer(
        self,
        layer: RadixAttention,
...
@@ -167,7 +172,7 @@ class KVCache(abc.ABC):
        raise NotImplementedError()


-class MHATokenToKVPool(BaseTokenToKVPool):
+class MHATokenToKVPool(KVCache):
    def __init__(
        self,
...
@@ -179,8 +184,14 @@ class MHATokenToKVPool(KVCache):
        device: str,
        enable_memory_saver: bool,
    ):
-       super().__init__(size, dtype, device)
+       self.size = size
+       self.dtype = dtype
+       self.device = device
+       if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+           # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+           self.store_dtype = torch.uint8
+       else:
+           self.store_dtype = dtype

        self.memory_saver_adapter = TorchMemorySaverAdapter.create(enable=enable_memory_saver)
...
@@ -297,7 +308,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
    dst_2[loc] = src_2.to(dtype).view(store_dtype)


-class MLATokenToKVPool(BaseTokenToKVPool):
+class MLATokenToKVPool(KVCache):
    def __init__(
        self,
        size: int,
...
@@ -308,8 +319,14 @@ class MLATokenToKVPool(KVCache):
        device: str,
        enable_memory_saver: bool,
    ):
-       super().__init__(size, dtype, device)
+       self.size = size
+       self.dtype = dtype
+       self.device = device
+       if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+           # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+           self.store_dtype = torch.uint8
+       else:
+           self.store_dtype = dtype

        self.kv_lora_rank = kv_lora_rank
        memory_saver_adapter = TorchMemorySaverAdapter.create(
...
@@ -356,7 +373,7 @@ class MLATokenToKVPool(KVCache):
        self.kv_buffer[layer_id][loc] = cache_k


-class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
+class DoubleSparseTokenToKVPool(KVCache):
    def __init__(
        self,
        size: int,
...
@@ -368,8 +385,14 @@ class DoubleSparseTokenToKVPool(KVCache):
        heavy_channel_num: int,
        enable_memory_saver: bool,
    ):
-       super().__init__(size, dtype, device)
+       self.size = size
+       self.dtype = dtype
+       self.device = device
+       if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+           # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+           self.store_dtype = torch.uint8
+       else:
+           self.store_dtype = dtype

        memory_saver_adapter = TorchMemorySaverAdapter.create(enable=enable_memory_saver)
...
@@ -437,12 +460,12 @@ def synchronized(func):
    return wrapper


-class MLATokenToKVPoolHost:
+class MHATokenToKVPoolHost:
    def __init__(
        self,
        device_pool: MHATokenToKVPool,
-       host_to_device_ratio: float = 4.0,
+       host_to_device_ratio: float = 2.0,
        pin_memory: bool = False,  # no need to use pin memory with the double buffering
        device: str = "cpu",
    ):
...
python/sglang/srt/mem_cache/radix_cache.py
...
@@ -26,8 +26,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
import torch

from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req
...
@@ -79,11 +80,11 @@ class RadixCache(BasePrefixCache):
    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
-       token_to_kv_pool: BaseTokenToKVPool,
+       token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        disable: bool = False,
    ):
        self.req_to_token_pool = req_to_token_pool
-       self.token_to_kv_pool = token_to_kv_pool
+       self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.disable = disable
        self.reset()
...
@@ -139,7 +140,7 @@ class RadixCache(BasePrefixCache):
            kv_indices = self.req_to_token_pool.req_to_token[
                req.req_pool_idx, :token_ids_len
            ]
-           self.token_to_kv_pool.free(kv_indices)
+           self.token_to_kv_pool_allocator.free(kv_indices)
            self.req_to_token_pool.free(req.req_pool_idx)
            return
...
@@ -151,7 +152,9 @@ class RadixCache(BasePrefixCache):
        # Radix Cache takes one ref in memory pool
        new_prefix_len = self.insert(token_ids, kv_indices.clone())
-       self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+       self.token_to_kv_pool_allocator.free(
+           kv_indices[len(req.prefix_indices) : new_prefix_len]
+       )

        # Remove req slot release the cache lock
        self.req_to_token_pool.free(req.req_pool_idx)
...
@@ -171,7 +174,9 @@ class RadixCache(BasePrefixCache):
        # Radix Cache takes one ref in memory pool
        new_prefix_len = self.insert(token_ids, kv_indices.clone())
-       self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+       self.token_to_kv_pool_allocator.free(
+           kv_indices[len(req.prefix_indices) : new_prefix_len]
+       )

        # The prefix indices could be updated, reuse it
        new_indices, new_last_node = self.match_prefix(token_ids)
...
python/sglang/srt/model_executor/model_runner.py
...
@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
+   TokenToKVPoolAllocator,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
...
@@ -98,6 +99,8 @@ class ModelRunner:
        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
+       req_to_token_pool: Optional[ReqToTokenPool] = None,
+       token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.model_config = model_config
...
@@ -115,6 +118,8 @@ class ModelRunner:
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
+       self.req_to_token_pool = req_to_token_pool
+       self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

        # Model-specific adjustment
        if (
...
@@ -257,8 +262,8 @@ class ModelRunner:
    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")

        torch.get_device_module(self.device).set_device(self.gpu_id)
        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
...
@@ -660,12 +665,25 @@ class ModelRunner:
        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+               max_num_reqs = self.server_args.max_num_reqs
            else:
+               # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
+               # can be concurrently allocated, so we should give a headroom for it.
+               self.server_args.draft_runner_cache_size = (
+                   self.max_total_num_tokens
+                   + max_num_reqs * self.server_args.speculative_num_steps  # draft
+                   + max_num_reqs
+                   * self.server_args.speculative_num_steps
+                   * self.server_args.speculative_eagle_topk  # verify
+                   + max_num_reqs * self.server_args.speculative_num_draft_tokens  # buffer
+                   + 100
+               )
+               # Target worker and draft worker shares the same indices for the
+               # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
+               self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+               self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
...
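To make the headroom formula above concrete, here is the arithmetic for one hypothetical configuration; the numbers are made up for illustration and are not taken from the commit.

# Hypothetical values plugged into the draft_runner_cache_size formula above.
max_total_num_tokens = 100_000
max_num_reqs = 32
speculative_num_steps = 5
speculative_eagle_topk = 4
speculative_num_draft_tokens = 8

draft_runner_cache_size = (
    max_total_num_tokens
    + max_num_reqs * speculative_num_steps                           # draft:  160
    + max_num_reqs * speculative_num_steps * speculative_eagle_topk  # verify: 640
    + max_num_reqs * speculative_num_draft_tokens                    # buffer: 256
    + 100
)
print(draft_runner_cache_size)  # 100000 + 160 + 640 + 256 + 100 = 101156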
@@ -681,12 +699,25 @@ class ModelRunner:
                "Not enough memory. Please try to increase --mem-fraction-static."
            )

-       self.req_to_token_pool = ReqToTokenPool(
-           size=max_num_reqs + 1,
-           max_context_len=self.model_config.context_len + 4,
-           device=self.device,
-           enable_memory_saver=self.server_args.enable_memory_saver,
-       )
+       if self.req_to_token_pool is None:
+           self.req_to_token_pool = ReqToTokenPool(
+               size=max_num_reqs + 1,
+               max_context_len=self.model_config.context_len + 4,
+               device=self.device,
+               enable_memory_saver=self.server_args.enable_memory_saver,
+           )
+       else:
+           # Draft worker shares req_to_token_pool with the target worker.
+           assert self.is_draft_worker
+
+       if self.token_to_kv_pool_allocator is None:
+           self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+               self.max_total_num_tokens,
+               dtype=self.kv_cache_dtype,
+               device=self.device,
+           )
+       else:
+           assert self.is_draft_worker

        if (
            self.model_config.attention_arch == AttentionArch.MLA
...
python/sglang/srt/server_args.py
...
@@ -280,11 +280,16 @@ class ServerArgs:
            self.disable_overlap_schedule = True
            self.prefill_only_one_req = True
            self.disable_cuda_graph_padding = True
            self.disable_radix_cache = True
            self.chunked_prefill_size = -1
+           if self.max_running_requests is None:
+               self.max_running_requests = 32
            logger.info(
-               f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
+               "Overlap scheduler are disabled because of using "
+               "eagle speculative decoding."
+               "Max running request set to 32 because of using eagle speculative decoding."
            )
+           # The token generated from the verify step is counted.
+           # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
+           assert self.speculative_num_steps < self.speculative_num_draft_tokens

        # GGUF
        if (
...
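The assert added above encodes a simple budget argument: each verify round can accept at most one token per draft step plus the bonus token from verification, so drafting more steps than there are draft-token slots only produces tokens that must be discarded. A tiny illustration with placeholder values (not defaults from the commit):

# Illustrative arithmetic behind the assert added above.
speculative_num_steps = 5          # draft model autoregressive steps per round
speculative_num_draft_tokens = 8   # tokens sent to the target model for verification

max_accepted_per_path = speculative_num_steps + 1  # drafted tokens + bonus token
assert speculative_num_steps < speculative_num_draft_tokens
print(max_accepted_per_path)  # 6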
python/sglang/srt/speculative/build_eagle_tree.py
...
@@ -3,14 +3,8 @@
from typing import List

import torch

-from sglang.srt.utils import is_cuda_available
-
-if is_cuda_available():
-    from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
-    from sgl_kernel import (
-        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
-    )
+from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
+from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient


def build_tree_kernel_efficient_preprocess(
...
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
...
@@ -21,7 +21,6 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.speculative.eagle_utils import EagleDraftInput

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.eagle_worker import EAGLEWorker
...
python/sglang/srt/speculative/eagle_utils.py
View file @
d3d4d767
from
__future__
import
annotations
import
dataclass
es
from
typing
import
TYPE_CHECKING
,
List
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
List
import
torch
import
torch.nn.functional
as
F
import
triton
import
triton.language
as
tl
from
sglang.srt.layers.attention.flashinfer_backend
import
(
create_flashinfer_kv_indices_triton
,
)
from
sglang.srt.layers.attention.utils
import
create_flashinfer_kv_indices_triton
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.mem_cache.memory_pool
import
TokenToKVPoolAllocator
from
sglang.srt.model_executor.forward_batch_info
import
CaptureHiddenMode
from
sglang.srt.speculative.build_eagle_tree
import
(
build_tree_kernel
,
...
...
@@ -25,7 +26,7 @@ if TYPE_CHECKING:
from
sglang.srt.managers.schedule_batch
import
ScheduleBatch
@
dataclass
es
.
dataclass
@
dataclass
class
EagleDraftInput
:
# The inputs for decode
# shape: (b, topk)
...
...
@@ -46,57 +47,46 @@ class EagleDraftInput:
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None

    # Indices of unfinished requests during extend-after-decode,
    # e.g. [0, 2, 3, 4] if only the 2nd request is finished.
    keep_indices: List[int] = None

    def prepare_for_extend(self, batch: ScheduleBatch):
        req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
        out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
        batch.out_cache_loc = out_cache_loc

        assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
        # Prefill only generates 1 token.
        assert len(self.verified_id) == len(batch.seq_lens)

        pt = 0
        for i, req in enumerate(batch.reqs):
            req.req_pool_idx = req_pool_indices[i]
            pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
            assert seq_len - pre_len == req.extend_input_len

            if pre_len > 0:
                batch.req_to_token_pool.req_to_token[req.req_pool_idx][:pre_len] = req.prefix_indices

            batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                out_cache_loc[pt : pt + req.extend_input_len]
            )

        for i, extend_len in enumerate(batch.extend_lens):
            input_ids = batch.input_ids[pt : pt + extend_len]
            batch.input_ids[pt : pt + extend_len] = torch.concat(
                (input_ids[1:], self.verified_id[i].reshape(1))
            )
            pt += req.extend_input_len

        # TODO: support batching inputs
        assert len(batch.extend_lens) == 1
        batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))

    def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
        batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
        assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
        accept_length_cpu = batch.spec_info.accept_length_cpu
        batch.extend_lens = [x + 1 for x in accept_length_cpu]
        batch.extend_num_tokens = sum(batch.extend_lens)
        batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
        seq_lens_cpu = batch.seq_lens.tolist()
        assert len(batch.req_pool_indices) == len(batch.reqs)

        pt = 0
        i = 0
        for req in batch.reqs:
        self.keep_indices = []
        for idx, req in enumerate(batch.reqs):
            if req.finished():
                continue
            self.keep_indices.append(idx)
            # assert seq_len - pre_len == req.extend_input_len
            input_len = batch.extend_lens[i]
            seq_len = seq_lens_cpu[i]
            batch.req_to_token_pool.req_to_token[req.req_pool_idx][
                seq_len - input_len : seq_len
            ] = batch.out_cache_loc[pt : pt + input_len]
            pt += input_len
            i += 1
        assert pt == batch.out_cache_loc.shape[0]

        self.positions = torch.empty_like(self.verified_id)
        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
        self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
        new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
        self.accept_length.add_(1)
        create_extend_spec_info[(self.accept_length.numel(),)](
...
...
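The extend-after-decode bookkeeping above turns per-request accepted lengths into extend lengths (accepted tokens plus one bonus token) and writes the newly allocated cache slots into each surviving request's row of the request-to-token table. A minimal plain-PyTorch sketch of that mapping, with toy sizes and hypothetical names, not part of the commit:

import torch

accept_length_cpu = [2, 0, 1]                       # accepted draft tokens per request
extend_lens = [x + 1 for x in accept_length_cpu]    # +1 bonus token -> [3, 1, 2]
out_cache_loc = torch.arange(100, 100 + sum(extend_lens))  # newly allocated KV slots

req_to_token = torch.zeros(4, 16, dtype=torch.long)  # toy request-to-token table
req_pool_idx = [0, 1, 3]                             # rows owned by the surviving requests
seq_lens = [7, 5, 9]                                 # sequence length after the extend

pt = 0
for idx, seq_len, input_len in zip(req_pool_idx, seq_lens, extend_lens):
    # Write the fresh slots into the tail of this request's row.
    req_to_token[idx, seq_len - input_len : seq_len] = out_cache_loc[pt : pt + input_len]
    pt += input_len
assert pt == out_cache_loc.shape[0]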
@@ -117,14 +107,22 @@ class EagleDraftInput:
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        req_to_token: torch.Tensor,
    ):
        bs = self.accept_length.numel()

        keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
        req_pool_indices = req_pool_indices[keep_indices]
        assert req_pool_indices.shape[0] == bs
        assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]

        qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

        # TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
        create_flashinfer_kv_indices_triton[(bs,)](
...
...
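The kernel invoked above flattens each request's KV slot list into one kv_indices array with cumulative-offset index pointers. A plain-PyTorch sketch of the same bookkeeping, assuming a dense req_to_token table and toy values; the commit itself does this on the GPU with create_flashinfer_kv_indices_triton:

import torch

def build_kv_indices(req_to_token, req_pool_indices, paged_kernel_lens):
    bs = req_pool_indices.numel()
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    # Gather the first `len` KV slots of every request into one flat index list.
    kv_indices = torch.cat(
        [req_to_token[req_pool_indices[i], : paged_kernel_lens[i]] for i in range(bs)]
    )
    return kv_indptr, kv_indices

req_to_token = torch.arange(32).reshape(4, 8)   # toy mapping: 4 request slots, 8 tokens each
req_pool_indices = torch.tensor([0, 2])
paged_kernel_lens = torch.tensor([3, 5])
kv_indptr, kv_indices = build_kv_indices(req_to_token, req_pool_indices, paged_kernel_lens)
# kv_indptr -> [0, 3, 8]; kv_indices -> [0, 1, 2, 16, 17, 18, 19, 20]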
@@ -162,7 +160,21 @@ class EagleDraftInput:
        self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])


@dataclasses.dataclass
@dataclass
class EagleVerifyOutput:
    # Draft input batch
    draft_input: EagleDraftInput
    # Logit outputs from target worker
    logits_output: LogitsProcessorOutput
    # Accepted token ids including the bonus token
    verified_id: torch.Tensor
    # Accepted token length per sequence in a batch, on CPU.
    accept_length_per_req_cpu: List[int]
    # Accepted indices from logits_output.next_token_logits
    accepeted_indices_cpu: List[int]


@dataclass
class EagleVerifyInput:
    draft_token: torch.Tensor
    custom_mask: torch.Tensor
...
...
@@ -267,6 +279,7 @@ class EagleVerifyInput:
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        req_to_token: torch.Tensor,
    ):
        batch_size = len(req_pool_indices)
...
...
@@ -285,7 +298,11 @@ class EagleVerifyInput:
        paged_kernel_lens = paged_kernel_lens + self.draft_token_num
        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

        kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
        kv_indices = torch.empty(
            paged_kernel_lens_sum + self.draft_token_num * batch_size,
            dtype=torch.int32,
            device="cuda",
        )
        create_flashinfer_kv_indices_triton[(batch_size,)](
            req_to_token,
...
...
@@ -298,7 +315,21 @@ class EagleVerifyInput:
        )
        return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask

    def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
    def verify(
        self,
        batch: ScheduleBatch,
        logits_output: torch.Tensor,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
    ) -> torch.Tensor:
        """WARNING: This API in-place modifies the states of logits_output

        Verify and find accepted tokens based on logits output and batch
        (which contains spec decoding information).

        This API updates values inside logits_output based on the accepted
        tokens, i.e., logits_output.next_token_logits only contains
        accepted token logits.
        """
        draft_token = torch.cat(
            [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
            dim=-1,
...
...
@@ -367,7 +398,6 @@ class EagleVerifyInput:
        new_accept_index = []
        unfinished_index = []
        finished_extend_len = {}  # {rid: accept_length + 1}
        accept_index_cpu = accept_index.tolist()
        predict_cpu = predict.tolist()
        has_finished = False
...
...
@@ -382,7 +412,6 @@ class EagleVerifyInput:
                id = predict_cpu[idx]
                # if not found_finished:
                req.output_ids.append(id)
                finished_extend_len[req.rid] = j + 1
                req.check_finished()
                if req.finished():
                    has_finished = True
...
...
@@ -400,11 +429,10 @@ class EagleVerifyInput:
        accept_index = accept_index[accept_index != -1]
        accept_length_cpu = accept_length.tolist()
        verified_id = predict[accept_index]

        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
        evict_mask[accept_index] = False
        mem_need_free_idx = batch.out_cache_loc[evict_mask]
        batch.token_to_kv_pool.free(mem_need_free_idx)
        token_to_kv_pool_allocator.free(mem_need_free_idx)
        assign_req_to_token_pool[(bs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
...
...
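A minimal sketch of the eviction step above, assuming out_cache_loc holds one KV slot per drafted token and accept_index marks the tokens kept by verification; the rejected slots are exactly the ones handed back to the allocator (illustrative only):

import torch

draft_token = torch.tensor([11, 12, 13, 14, 15, 16])     # 6 drafted tokens
out_cache_loc = torch.tensor([100, 101, 102, 103, 104, 105])
accept_index = torch.tensor([0, 2, 3])                    # tokens kept after verification

evict_mask = torch.full_like(draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
mem_need_free_idx = out_cache_loc[evict_mask]             # -> [101, 104, 105]
# In the diff, these slots are passed to token_to_kv_pool_allocator.free(...).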
@@ -427,20 +455,16 @@ class EagleVerifyInput:
        ]
        if has_finished:
            draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[unfinished_index]
        else:
            draft_input.seq_lens_for_draft_extend = batch.seq_lens
            draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices

        logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
        return (
            draft_input,
            logits_output,
            verified_id,
            finished_extend_len,
            accept_length_cpu,
        batch.out_cache_loc = batch.out_cache_loc[new_accept_index]

        return EagleVerifyOutput(
            draft_input=draft_input,
            logits_output=logits_output,
            verified_id=verified_id,
            accept_length_per_req_cpu=accept_length_cpu,
            accepeted_indices_cpu=accept_index,
        )
...
...
@@ -456,6 +480,18 @@ def eagle_verify_retrive(
    draft_token_num: tl.constexpr,
    max_len_upper: tl.constexpr,
):
    """
    Args:
        retrive_index: Pointer to indices of draft tokens
        accept_mask: Mask indicating which tokens were accepted
        retrive_cum_len: Cumulative lengths of token sequences in a batch
        accept_index (out): Accept token indices
        accept_length (out): Length of accepted tokens per sequence in a batch
        extract_index (out): Index for last accepted tokens
        max_len: Maximum length in a batch
        draft_token_num: Number of tokens speculatively generated
        max_len_upper: An upper bound for token sequence length
    """
    pid = tl.program_id(axis=0)
    retrive_end = tl.load(retrive_cum_len + pid + 1)
...
...
@@ -649,7 +685,7 @@ def generate_draft_decode_kv_indices(
        tl.store(kv_indptr + zid, base + zid * iters)


@torch.compile
@torch.compile(dynamic=True)
def select_top_k_tokens(
    i: int,
    topk_p: torch.Tensor,
...
...
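The decorator in the hunk above switches from plain torch.compile to torch.compile(dynamic=True). A tiny sketch of the same pattern on a toy function; dynamic=True asks the compiler to treat shapes symbolically, so the varying batch sizes seen during speculative decoding are less likely to trigger a recompile per shape (illustrative only):

import torch

@torch.compile(dynamic=True)
def scale_and_sum(x: torch.Tensor) -> torch.Tensor:
    # Trivial element-wise op followed by a reduction.
    return (x * 2).sum()

print(scale_and_sum(torch.ones(4)))
print(scale_and_sum(torch.ones(7)))   # different length, served by the same dynamic graph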
@@ -671,13 +707,11 @@ def select_top_k_tokens(
            .unsqueeze(0)
            .repeat(topk_p.shape[0], 1),
            # shape: (b, topk + 1)
        )
    else:
        # The later decode steps
        expand_scores = torch.mul(
            scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
        )  # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)

        topk_cs_p, topk_cs_index = fast_topk(
            expand_scores.flatten(start_dim=1), topk, dim=-1
        )  # (b, topk)
...
...
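A minimal sketch of the "later decode steps" branch above: each kept parent's cumulative score is combined with its per-parent top-k child probabilities, and the global top-k over all (parent, child) pairs is selected. torch.topk stands in for the fast_topk helper (illustrative only):

import torch

b, topk = 2, 3
scores = torch.rand(b, topk)           # cumulative score of each kept parent branch
topk_p = torch.rand(b, topk * topk)    # per-parent child probabilities, flattened

expand_scores = scores.unsqueeze(2) * topk_p.reshape(-1, topk, topk)   # (b, topk, topk)
topk_cs_p, topk_cs_index = torch.topk(expand_scores.flatten(start_dim=1), topk, dim=-1)
parent_idx, child_idx = topk_cs_index // topk, topk_cs_index % topk    # recover tree position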
python/sglang/srt/speculative/eagle_worker.py
View file @
d3d4d767
import logging
import os
import time
from typing import List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

import torch
from huggingface_hub import snapshot_download
...
...
@@ -22,11 +22,13 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
from sglang.srt.speculative.eagle_utils import (
    EagleDraftInput,
    EagleVerifyInput,
    EagleVerifyOutput,
    assign_draft_cache_locs,
    fast_topk,
    select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import get_available_gpu_memory

logger = logging.getLogger(__name__)
...
...
@@ -42,12 +44,16 @@ class EAGLEWorker(TpModelWorker):
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        # Override context length with target model's context length
        server_args.context_length = target_worker.model_runner.model_config.context_len
        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"

        # Do not capture cuda graph in `super().__init__()`
        # We will capture it later
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True

        # Load hot token ids
        # Lossy optimization by using hot tokens
        if server_args.speculative_token_map is not None:
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            server_args.json_model_override_args = (
...
...
@@ -56,6 +62,12 @@ class EAGLEWorker(TpModelWorker):
        else:
            self.hot_token_id = None

        # We share the allocator with the target worker.
        # The draft and target workers each own their own KV cache.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Init target worker
        super().__init__(
            gpu_id=gpu_id,
...
...
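A minimal sketch of the sharing arrangement described by the comment above: the draft and target workers hold the same allocator object while each keeps its own KV storage, so one allocation governs slot ids in both caches. The classes below are stand-ins invented for illustration, not the sglang pool classes:

class FreeSlotAllocator:
    def __init__(self, num_slots):
        self.free_slots = list(range(num_slots))

    def alloc(self, n):
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def free(self, slots):
        self.free_slots.extend(slots)

shared_allocator = FreeSlotAllocator(num_slots=8)
target_kv_cache = [None] * 8          # target model's own KV storage
draft_kv_cache = [None] * 8           # draft model's own KV storage, same slot ids
slots = shared_allocator.alloc(3)     # one allocation is valid for both caches
assert shared_allocator.free_slots == [3, 4, 5, 6, 7]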
@@ -64,9 +76,10 @@ class EAGLEWorker(TpModelWorker):
            nccl_port=nccl_port,
            dp_rank=dp_rank,
            is_draft_worker=True,
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
        )
        self.target_worker = target_worker
        self.finish_extend_len = []

        # Parse arguments
        self.topk = server_args.speculative_eagle_topk
...
...
@@ -75,6 +88,9 @@ class EAGLEWorker(TpModelWorker):
            server_args.speculative_algorithm
        )
        self.server_args = server_args
        self.use_nan_detection = self.server_args.enable_nan_detection
        self.device = self.model_runner.device
        self.gpu_id = self.model_runner.gpu_id

        # Share the embedding and lm_head
        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
...
...
@@ -82,8 +98,10 @@ class EAGLEWorker(TpModelWorker):
            head = head.clone()
            self.hot_token_id = self.hot_token_id.to(head.device)
            head.data = head.data[self.hot_token_id]
        self.model_runner.model.set_embed_and_head(embed, head)
        self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
        self.draft_model_runner.model.set_embed_and_head(embed, head)
        self.draft_model_runner.server_args.disable_cuda_graph = (
            backup_disable_cuda_graph
        )

        # Create multi-step attn backends and cuda graph runners
        if server_args.attention_backend == "flashinfer":
...
...
@@ -111,7 +129,7 @@ class EAGLEWorker(TpModelWorker):
                f"EAGLE is not supported in attention backend {server_args.attention_backend}"
            )

        self.model_runner.draft_attn_backend = self.draft_attn_backend
        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
        self.init_cuda_graphs()

    def init_cuda_graphs(self):
...
...
@@ -122,55 +140,81 @@ class EAGLEWorker(TpModelWorker):
            return

        tic = time.time()
        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        logger.info(
            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )
        self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
        logger.info(
            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
    @property
    def draft_model_runner(self):
        return self.model_runner

    def forward_batch_speculative_generation(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
        """Run the speculative decoding forward.

        NOTE: Many states of the batch are modified as the forward runs, so the final
        batch is not guaranteed to have the same state as the input.

        Args:
            batch: The batch to run forward. The state of the batch is modified as it runs.
        Returns:
            A tuple of the final logit output of the target model, next tokens accepted,
            the batch id (used for overlap schedule), and number of accepted tokens.
        """
        assert not batch.spec_algorithm.is_none()
        if batch.forward_mode.is_decode():
            # Draft
            spec_info: EagleVerifyInput = self.draft(batch)

            # Verify
            (
                next_draft_input,
                logits_output,
                verified_id,
                self.finish_extend_len,
                accept_length_cpu,
                model_worker_batch,
            ) = self.verify(batch, spec_info)
            batch.spec_info = next_draft_input
            # if it is None, it means all requests are finished
            spec_info, to_free_cache_loc = self.draft(batch)
            logits_output, verify_output, model_worker_batch = self.verify(batch, spec_info)

            # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
            self.token_to_kv_pool_allocator.free(to_free_cache_loc)

            # if it is None, it means all requests are finished
            if batch.spec_info.verified_id is not None:
                self.forward_draft_extend_after_decode(batch)

            return (
                logits_output,
                verified_id,
                model_worker_batch,
                sum(accept_length_cpu),
                verify_output.verified_id,
                model_worker_batch.bid,
                sum(verify_output.accept_length_per_req_cpu),
            )
        else:
            # Forward with the target model and get hidden states.
            # We need the full hidden states to prefill the KV cache of the draft model.
            model_worker_batch = batch.get_model_worker_batch()
            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
                model_worker_batch
            )

            # Forward with the draft model.
            batch.spec_info = EagleDraftInput(
                hidden_states=logits_output.hidden_states,
                verified_id=next_token_ids,
            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
            self.forward_draft_extend(batch, logits_output.hidden_states, next_token_ids)
            self.forward_draft_extend(batch)
            return logits_output, next_token_ids, model_worker_batch, 0
            return logits_output, next_token_ids, bid, 0

    def forward_target_extend(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
        """Run the target extend.

        Args:
            batch: The batch to run. States could be modified.
        Returns:
            logits_output: The output of logits. It will contain the full hidden states.
            next_token_ids: Next token ids generated.
            bid: The model batch ID. Used for overlap schedule.
        """
        # Forward with the target model and get hidden states.
        # We need the full hidden states to prefill the KV cache of the draft model.
        model_worker_batch = batch.get_model_worker_batch()
        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
            model_worker_batch
        )
        return logits_output, next_token_ids, model_worker_batch.bid

    def draft(self, batch: ScheduleBatch):
        self._set_mem_pool(batch, self.model_runner)

        # Parse args
        num_seqs = batch.batch_size()
        spec_info = batch.spec_info
...
@@ -188,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
self
.
topk
,
self
.
speculative_num_steps
,
)
batch
.
out_cache_loc
=
out_cache_loc
batch
.
seq_lens_sum
=
torch
.
sum
(
batch
.
seq_lens
).
item
()
spec_info
.
positions
=
batch
.
seq_lens
.
repeat_interleave
(
self
.
topk
,
dim
=
0
)
...
...
@@ -196,11 +239,12 @@ class EAGLEWorker(TpModelWorker):
# Get forward batch
spec_info
.
capture_hidden_mode
=
CaptureHiddenMode
.
LAST
model_worker_batch
=
batch
.
get_model_worker_batch
()
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
model_runner
)
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
draft_model_runner
)
can_cuda_graph
=
self
.
cuda_graph_runner
and
self
.
cuda_graph_runner
.
can_run
(
forward_batch
)
if
can_cuda_graph
:
score_list
,
token_list
,
parents_list
=
self
.
cuda_graph_runner
.
replay
(
forward_batch
...
...
@@ -208,7 +252,9 @@ class EAGLEWorker(TpModelWorker):
else
:
# Initialize attention backend
self
.
draft_attn_backend
.
init_forward_metadata
(
forward_batch
)
forward_batch
=
ForwardBatch
.
init_new
(
model_worker_batch
,
self
.
draft_model_runner
)
# Run forward steps
score_list
,
token_list
,
parents_list
=
self
.
draft_forward
(
forward_batch
)
...
...
@@ -225,10 +271,7 @@ class EAGLEWorker(TpModelWorker):
            batch.sampling_info.is_all_greedy,
        )

        # Free cache locations
        batch.token_to_kv_pool.free(out_cache_loc)
        self._set_mem_pool(batch, self.target_worker.model_runner)
        return ret
        return ret, out_cache_loc

    def draft_forward(self, forward_batch: ForwardBatch):
        # Parse args
...
...
@@ -278,6 +321,7 @@ class EAGLEWorker(TpModelWorker):
            logits_output = self.model_runner.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
            self._detect_nan_if_needed(logits_output)
            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
            if self.hot_token_id is not None:
...
...
@@ -294,71 +338,88 @@ class EAGLEWorker(TpModelWorker):
        logits_output, _ = self.target_worker.forward_batch_generation(
            model_worker_batch, skip_sample=True
        )
        self._detect_nan_if_needed(logits_output)
        spec_info.hidden_states = logits_output.hidden_states
        res = spec_info.verify(batch, logits_output)
        res: EagleVerifyOutput = spec_info.verify(
            batch, logits_output, self.token_to_kv_pool_allocator
        )

        # Post process based on verified outputs.
        # Pick the indices that we care about (accepted)
        logits_output.next_token_logits = logits_output.next_token_logits[
            res.accepeted_indices_cpu
        ]
        logits_output.hidden_states = logits_output.hidden_states[res.accepeted_indices_cpu]

        # Prepare the batch for the next draft forwards.
        batch.forward_mode = ForwardMode.DECODE
        return res + (model_worker_batch,)
        batch.spec_info = res.draft_input

    def forward_draft_extend(self, batch: ScheduleBatch):
        self._set_mem_pool(batch, self.model_runner)
        return logits_output, res, model_worker_batch

    def forward_draft_extend(
        self,
        batch: ScheduleBatch,
        hidden_states: torch.Tensor,
        next_token_ids: List[int],
    ):
        """Run the draft model extend. This API modifies the states of the batch.

        Args:
            batch: The batch to run.
            hidden_states: Hidden states from the target model forward.
            next_token_ids: Next token ids generated from the target forward.
        """
        batch.spec_info = EagleDraftInput(
            hidden_states=hidden_states,
            verified_id=next_token_ids,
        )
        batch.spec_info.prepare_for_extend(batch)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        logits_output = self.model_runner.forward(forward_batch)
        self.capture_for_decode(logits_output, forward_batch)
        self._set_mem_pool(batch, self.target_worker.model_runner)

    def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
        batch.token_to_kv_pool = runner.token_to_kv_pool
        batch.req_to_token_pool = runner.req_to_token_pool
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.draft_model_runner)
        logits_output = self.draft_model_runner.forward(forward_batch)
        self._detect_nan_if_needed(logits_output)
        assert isinstance(forward_batch.spec_info, EagleDraftInput)
        assert forward_batch.spec_info is batch.spec_info
        self.capture_for_decode(logits_output, forward_batch.spec_info)

    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
        seq_lens_backup = batch.seq_lens
        req_pool_indices_backup = batch.req_pool_indices
        self._set_mem_pool(batch, self.model_runner)

        batch.forward_mode = ForwardMode.DRAFT_EXTEND
        batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        # We don't need logprob for this extend.
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        logits_output = self.model_runner.forward(forward_batch)
        self.capture_for_decode(logits_output, forward_batch)
        self._set_mem_pool(batch, self.target_worker.model_runner)
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.draft_model_runner)
        logits_output = self.draft_model_runner.forward(forward_batch)
        self._detect_nan_if_needed(logits_output)
        assert forward_batch.spec_info is batch.spec_info
        self.capture_for_decode(logits_output, forward_batch.spec_info)

        # Restore the backup.
        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`.
        batch.forward_mode = ForwardMode.DECODE
        batch.seq_lens = seq_lens_backup
        batch.req_pool_indices = req_pool_indices_backup

    def capture_for_decode(
        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
    ):
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        spec_info = forward_batch.spec_info
        spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
        spec_info.hidden_states = logits_output.hidden_states

    # Don't support prefix share now.
    def finish_request(self, reqs: Union[Req, List[Req]]):
        if not isinstance(reqs, List):
            reqs = [reqs]
        for req in reqs:
            if req.rid not in self.finish_extend_len:
                continue
            req_len = (
                len(req.origin_input_ids)
                + len(req.output_ids)
                - self.finish_extend_len[req.rid]
                - 1
            )
            kv_indices = self.model_runner.req_to_token_pool.req_to_token[req.req_pool_idx][:req_len]
            self.model_runner.token_to_kv_pool.free(kv_indices)
            self.model_runner.req_to_token_pool.free(req.req_pool_idx)
        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
        draft_input.hidden_states = logits_output.hidden_states

    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
        if self.use_nan_detection:
            logits = logits_output.next_token_logits
            if torch.any(torch.isnan(logits)):
                logger.warning("Detected errors during sampling! NaN in the logits.")
                raise ValueError("Detected errors during sampling! NaN in the logits.")


def load_token_map(token_map_path: str) -> List[int]:
...
...
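A minimal sketch of the NaN guard added in this file: inspect the next-token logits before sampling and fail loudly if any entry is NaN (illustrative only; the worker's _detect_nan_if_needed also logs a warning first):

import torch

def detect_nan(next_token_logits: torch.Tensor) -> None:
    if torch.any(torch.isnan(next_token_logits)):
        raise ValueError("Detected errors during sampling! NaN in the logits.")

detect_nan(torch.tensor([0.1, 2.3, -1.0]))        # passes silently
# detect_nan(torch.tensor([0.1, float("nan")]))   # would raise ValueError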