change / sglang · Commits · b26bc86b

Commit b26bc86b (unverified)
Authored Mar 30, 2025 by Lianmin Zheng; committed by GitHub on Mar 30, 2025

Support page size > 1 + eagle (#4908)

Parent: 5ec5eaf7

Showing 16 changed files with 374 additions and 71 deletions (+374 / -71):
  python/pyproject.toml                                       +1    -0
  python/sglang/srt/layers/attention/flashinfer_backend.py    +5    -7
  python/sglang/srt/managers/schedule_batch.py                +27   -5
  python/sglang/srt/managers/scheduler.py                     +2    -2
  python/sglang/srt/mem_cache/memory_pool.py                  +6    -0
  python/sglang/srt/mem_cache/paged_allocator.py              +15   -0
  python/sglang/srt/model_executor/cuda_graph_runner.py       +6    -4
  python/sglang/srt/models/llama.py                           +1    -1
  python/sglang/srt/server_args.py                            +52   -4
  python/sglang/srt/speculative/eagle_utils.py                +140  -28
  python/sglang/srt/speculative/eagle_worker.py               +65   -11
  python/sglang/test/test_utils.py                            +11   -4
  scripts/ci_install_dependency.sh                            +1    -1
  test/srt/run_suite.py                                       +1    -1
  test/srt/test_eagle_infer.py                                +40   -3
  test/srt/test_mla_flashinfer.py                             +1    -0
python/pyproject.toml  (+1, -0)

@@ -33,6 +33,7 @@ runtime_common = [
     "prometheus-client>=0.20.0",
     "psutil",
     "pydantic",
+    "pynvml",
     "python-multipart",
     "pyzmq>=25.1.2",
     "soundfile==0.13.1",
python/sglang/srt/layers/attention/flashinfer_backend.py  (+5, -7)

@@ -14,7 +14,6 @@ from functools import partial
 from typing import TYPE_CHECKING, Callable, List, Optional, Union

 import torch
-import triton

 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend

@@ -22,7 +21,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+from sglang.srt.utils import is_flashinfer_available, next_power_of_2

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention

@@ -932,6 +931,7 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        self.page_size = model_runner.page_size

         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(

@@ -985,9 +985,9 @@ class FlashInferMultiStepDraftBackend:
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
-            triton.next_power_of_2(num_seqs),
-            triton.next_power_of_2(self.speculative_num_steps),
-            triton.next_power_of_2(bs),
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
+            next_power_of_2(bs),
         )

         assert forward_batch.spec_info is not None

@@ -1018,8 +1018,6 @@ class FlashInferMultiStepDraftBackend:
         )

         def call_fn(i, forward_batch):
-            assert forward_batch.spec_info is not None
-            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
python/sglang/srt/managers/schedule_batch.py  (+27, -5)

@@ -740,11 +740,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         )
         return req_pool_indices

-    def alloc_token_slots(self, num_tokens: int):
+    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
         if self.token_to_kv_pool_allocator.available_size() < num_tokens:
             if self.tree_cache is not None:
                 self.tree_cache.evict(num_tokens)

+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
         if out_cache_loc is None:
             phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"

@@ -758,7 +761,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 self.tree_cache.pretty_print()
             raise RuntimeError(error_msg)

-        return out_cache_loc
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc

     def alloc_paged_token_slots_extend(
         self,

@@ -766,6 +772,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         seq_lens: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
+        backup_state: bool = False,
     ):
         if (
             self.token_to_kv_pool_allocator.available_size()

@@ -778,6 +785,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
             )

+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
             prefix_lens, seq_lens, last_loc, extend_num_tokens
         )

@@ -791,12 +801,17 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
-        return out_cache_loc
+
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc

     def alloc_paged_token_slots_decode(
         self,
         seq_lens: torch.Tensor,
         last_loc: torch.Tensor,
+        backup_state: bool = False,
     ):
         if (
             self.token_to_kv_pool_allocator.available_size()

@@ -806,8 +821,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.tree_cache.evict(
                 len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
             )
+
+        if backup_state:
+            state = self.token_to_kv_pool_allocator.backup_state()
+
         out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
         if out_cache_loc is None:
             error_msg = (
                 f"Decode out of memory. Try to lower your batch size.\n"

@@ -818,7 +836,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
-        return out_cache_loc
+
+        if backup_state:
+            return out_cache_loc, state
+        else:
+            return out_cache_loc

     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
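Note on the new backup_state flag: when it is True, the allocator is snapshotted before allocating and the snapshot is returned alongside the slot indices, so the caller can roll the whole allocation back at once. A minimal sketch of the calling convention (the surrounding variable names are illustrative, not from this diff):

    # Hedged sketch of the backup/restore pattern introduced above.
    out_cache_loc, allocator_backup = batch.alloc_token_slots(
        num_draft_tokens, backup_state=True  # num_draft_tokens is illustrative
    )
    # ... run a draft forward pass that writes KV into out_cache_loc ...
    # Drop every draft slot at once instead of freeing them index by index.
    batch.token_to_kv_pool_allocator.restore_state(allocator_backup)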
python/sglang/srt/managers/scheduler.py  (+2, -2)

@@ -1110,7 +1110,7 @@ class Scheduler(
         )
         if memory_leak:
             msg = (
-                "KV cache pool leak detected! "
+                "token_to_kv_pool_allocator memory leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"

@@ -1121,7 +1121,7 @@ class Scheduler(
         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             msg = (
-                "Memory pool leak detected!"
+                "req_to_token_pool memory leak detected!"
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
python/sglang/srt/mem_cache/memory_pool.py  (+6, -0)

@@ -185,6 +185,12 @@ class TokenToKVPoolAllocator:
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def backup_state(self):
+        return self.free_slots
+
+    def restore_state(self, free_slots):
+        self.free_slots = free_slots
+
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = torch.arange(
python/sglang/srt/mem_cache/paged_allocator.py  (+15, -0)

@@ -218,6 +218,9 @@ class PagedTokenToKVPoolAllocator:
             next_power_of_2(extend_num_tokens),
         )

+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
         merged_value = self.ret_values.item()
         num_new_pages = merged_value >> 32
         if num_new_pages > len(self.free_pages):

@@ -248,6 +251,9 @@ class PagedTokenToKVPoolAllocator:
             self.page_size,
         )

+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
         num_new_pages = self.ret_values.item()
         if num_new_pages > len(self.free_pages):
             return None

@@ -265,6 +271,9 @@ class PagedTokenToKVPoolAllocator:
         else:
             self.free_group.append(free_index)

+        if self.debug_mode:
+            assert len(torch.unique(self.free_pages)) == len(self.free_pages)
+
     def free_group_begin(self):
         self.is_not_in_free_group = False
         self.free_group = []

@@ -274,6 +283,12 @@ class PagedTokenToKVPoolAllocator:
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def backup_state(self):
+        return self.free_pages
+
+    def restore_state(self, free_pages):
+        self.free_pages = free_pages
+
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(
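Note: backup_state()/restore_state() here just hand back and reassign the free_pages tensor; this presumably relies on allocation reassigning free_pages rather than mutating it in place, so an earlier snapshot stays valid. A toy stand-in (not sglang code) that mirrors the same pattern:

    import torch

    class ToyPagedAllocator:
        # Toy model of the snapshot/restore idea added in this commit.
        def __init__(self, num_pages: int):
            self.free_pages = torch.arange(1, num_pages + 1)

        def alloc(self, num_pages: int):
            out = self.free_pages[:num_pages]
            # Reassign instead of mutating, so an earlier snapshot stays valid.
            self.free_pages = self.free_pages[num_pages:]
            return out

        def backup_state(self):
            return self.free_pages

        def restore_state(self, free_pages):
            self.free_pages = free_pages

    alloc = ToyPagedAllocator(8)
    snapshot = alloc.backup_state()
    alloc.alloc(3)                 # free_pages is now [4..8]
    alloc.restore_state(snapshot)  # back to [1..8]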
python/sglang/srt/model_executor/cuda_graph_runner.py  (+6, -4)

@@ -116,16 +116,18 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if capture_bs is None:
         if server_args.speculative_algorithm is None:
             if server_args.disable_cuda_graph_padding:
-                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
+                capture_bs = list(range(1, 33)) + range(40, 161, 16)
             else:
-                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+                capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
         else:
             # Since speculative decoding requires more cuda graph memory, we
             # capture less.
-            capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
+            capture_bs = (
+                list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
+            )

     if _is_hip:
-        capture_bs += [i * 8 for i in range(21, 33)]
+        capture_bs += list(range(160, 257, 8))

     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
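For reference, the new capture list used with speculative decoding expands as follows (a quick sanity check, not part of the diff):

    # Expansion of the speculative-decoding branch above.
    capture_bs = list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
    # == [1, 2, 3, 4, 5, 6, 7, 8]
    #    + [10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32]
    #    + [40, 56, 72, 88, 104, 120, 136, 152]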
python/sglang/srt/models/llama.py  (+1, -1)

@@ -17,7 +17,7 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""

 import logging
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 from torch import nn
python/sglang/srt/server_args.py  (+52, -4)

@@ -15,6 +15,7 @@
 import argparse
 import dataclasses
+import json
 import logging
 import os
 import random

@@ -132,9 +133,9 @@ class ServerArgs:
     # Speculative decoding
     speculative_algorithm: Optional[str] = None
     speculative_draft_model_path: Optional[str] = None
-    speculative_num_steps: int = 5
-    speculative_eagle_topk: int = 4
-    speculative_num_draft_tokens: int = 8
+    speculative_num_steps: Optional[int] = None
+    speculative_eagle_topk: Optional[int] = None
+    speculative_num_draft_tokens: Optional[int] = None
     speculative_accept_threshold_single: float = 1.0
     speculative_accept_threshold_acc: float = 1.0
     speculative_token_map: Optional[str] = None

@@ -313,12 +314,29 @@ class ServerArgs:
             or self.speculative_algorithm == "EAGLE3"
         ):
             if self.max_running_requests is None:
-                self.max_running_requests = 32
+                self.max_running_requests = 48
             self.disable_overlap_schedule = True
             logger.info(
                 "Overlap scheduler is disabled because of using "
                 "eagle speculative decoding."
             )

+            # Auto choose parameters
+            if self.speculative_num_steps is None:
+                assert (
+                    self.speculative_eagle_topk is None
+                    and self.speculative_num_draft_tokens is None
+                )
+                (
+                    self.speculative_num_steps,
+                    self.speculative_eagle_topk,
+                    self.speculative_num_draft_tokens,
+                ) = auto_choose_speculative_params(self)
+
+            if self.page_size > 1 and self.speculative_eagle_topk > 1:
+                self.speculative_eagle_topk = 1
+                logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
+
             # The token generated from the verify step is counted.
             # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
             # assert self.speculative_num_steps < self.speculative_num_draft_tokens

@@ -1253,3 +1271,33 @@ class DeprecatedAction(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
         raise ValueError(self.help)
+
+
+def auto_choose_speculative_params(self: ServerArgs):
+    """
+    Automatically choose the parameters for speculative decoding.
+
+    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
+    """
+    if self.decrypted_config_file:
+        config_path = self.decrypted_config_file
+    else:
+        config_path = os.path.join(self.model_path, "config.json")
+    if not os.path.exists(config_path):
+        raise ValueError(f"{config_path} is not found.")
+
+    config = json.load(open(config_path))
+
+    arch = config.get("architectures", ["Unknown"])[0]
+
+    if arch in ["LlamaForCausalLM"]:
+        # The default value for llama
+        return (5, 4, 8)
+    elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
+        # The default value for deepseek
+        return (5, 4, 8)
+    elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
+        return (5, 4, 8)
+    else:
+        # The default value for all other models
+        return (5, 4, 8)
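With these defaults in place, the three speculative knobs can be left unset and are filled in from architectures[0] of the target model's config.json. A hypothetical launch that relies on the auto-chosen values (the model paths are placeholders, not from this commit):

    python -m sglang.launch_server \
      --model-path <target-model> \
      --speculative-algorithm EAGLE \
      --speculative-draft-model-path <draft-model> \
      --page-size 4
    # speculative-num-steps / eagle-topk / num-draft-tokens are auto-chosen;
    # because --page-size > 1, speculative_eagle_topk is then forced to 1.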
python/sglang/srt/speculative/eagle_utils.py  (+140, -28)

 from __future__ import annotations

+import os
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional

@@ -10,11 +11,15 @@ import triton.language as tl
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.managers.schedule_batch import (
+    ScheduleBatch,
+    get_last_loc,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import is_cuda_available, is_hip, next_power_of_2

 if is_cuda_available():
     from sgl_kernel import (

@@ -34,6 +39,9 @@ import logging
 logger = logging.getLogger(__name__)

+SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
+

 @dataclass
 class EagleDraftInput:
     # The inputs for decode

@@ -93,7 +101,7 @@ class EagleDraftInput:
             torch.cumsum(self.accept_length, axis=0, dtype=torch.int),
             self.positions,
             new_verified_id,
-            triton.next_power_of_2(speculative_num_steps + 1),
+            next_power_of_2(speculative_num_steps + 1),
         )

         batch.seq_lens_sum = sum(seq_lens_cpu)

@@ -225,18 +233,34 @@ class EagleVerifyInput:
             CaptureHiddenMode.FULL,
         )

-    def prepare_for_verify(self, batch: ScheduleBatch):
+    def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
         batch.input_ids = self.draft_token
-        batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
+
+        if page_size == 1:
+            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+            end_offset = batch.seq_lens + self.draft_token_num
+        else:
+            prefix_lens = batch.seq_lens
+            end_offset = prefix_lens + self.draft_token_num
+            last_loc = get_last_loc(
+                batch.req_to_token_pool.req_to_token,
+                batch.req_pool_indices,
+                prefix_lens,
+            )
+            batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
+                prefix_lens, end_offset, last_loc, len(batch.input_ids)
+            )
+            self.last_loc = last_loc
+
         bs = batch.batch_size()
         assign_req_to_token_pool[(bs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,
             batch.seq_lens,
-            batch.seq_lens + self.draft_token_num,
+            end_offset,
             batch.out_cache_loc,
             batch.req_to_token_pool.req_to_token.shape[1],
-            triton.next_power_of_2(bs),
+            next_power_of_2(bs),
         )

     def generate_attn_arg_prefill(

@@ -282,6 +306,7 @@ class EagleVerifyInput:
         batch: ScheduleBatch,
         logits_output: torch.Tensor,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        page_size: int,
     ) -> torch.Tensor:
         """
         Verify and find accepted tokens based on logits output and batch

@@ -305,6 +330,7 @@ class EagleVerifyInput:
         )
         accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")

+        # Apply penalty
         if sampling_info.penalizer_orchestrator.is_required:
             # This is a relaxed version of penalties for speculative decoding.
             linear_penalty = torch.zeros(

@@ -317,6 +343,7 @@ class EagleVerifyInput:
                 torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
             )

+        # Sample tokens
         if batch.sampling_info.is_all_greedy:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)

@@ -378,13 +405,24 @@ class EagleVerifyInput:
                 deterministic=True,
             )

+        if SIMULATE_ACC_LEN:
+            # Do simulation
+            accept_index = _generate_simulated_accept_index(
+                accept_index=accept_index,
+                predict=predict,  # mutable
+                accept_length=accept_length,  # mutable
+                simulate_acc_len=SIMULATE_ACC_LEN,
+                bs=bs,
+                spec_steps=self.spec_steps,
+            )
+
         new_accept_index = []
         unfinished_index = []
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
         has_finished = False

-        # iterate every accepted token and check if req has finished after append the token
+        # Iterate every accepted token and check if req has finished after append the token
         # should be checked BEFORE free kv cache slots
         for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
             new_accept_index_ = []

@@ -407,13 +445,28 @@ class EagleVerifyInput:
                 unfinished_index.append(i)
             req.spec_verify_ct += 1

+        if has_finished:
+            accept_length = (accept_index != -1).sum(dim=1) - 1
+
+        # Free the KV cache for unaccepted tokens
+        accept_index = accept_index[accept_index != -1]
+        verified_id = predict[accept_index]
+        evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+        evict_mask[accept_index] = False
+
+        if page_size != 1:
+            align_evict_mask_to_page_size[len(batch.seq_lens),](
+                batch.seq_lens,
+                evict_mask,
+                page_size,
+                self.draft_token_num,
+                next_power_of_2(self.draft_token_num),
+            )
+
+        token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+
+        # Construct EagleVerifyOutput
         if not has_finished:
-            accept_index = accept_index[accept_index != -1]
-            verified_id = predict[accept_index]
-            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-            evict_mask[accept_index] = False
-            mem_need_free_idx = batch.out_cache_loc[evict_mask]
-            token_to_kv_pool_allocator.free(mem_need_free_idx)
             batch.out_cache_loc = batch.out_cache_loc[accept_index]
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,

@@ -422,7 +475,7 @@ class EagleVerifyInput:
                 batch.seq_lens + accept_length + 1,
                 batch.out_cache_loc,
                 batch.req_to_token_pool.req_to_token.shape[1],
-                triton.next_power_of_2(bs),
+                next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()

@@ -443,13 +496,6 @@ class EagleVerifyInput:
                 accepeted_indices=accept_index,
             )
         else:
-            accept_length = (accept_index != -1).sum(dim=1) - 1
-            accept_index = accept_index[accept_index != -1]
-            verified_id = predict[accept_index]
-            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
-            evict_mask[accept_index] = False
-            mem_need_free_idx = batch.out_cache_loc[evict_mask]
-            token_to_kv_pool_allocator.free(mem_need_free_idx)
             assign_req_to_token_pool[(bs,)](
                 batch.req_pool_indices,
                 batch.req_to_token_pool.req_to_token,

@@ -457,7 +503,7 @@ class EagleVerifyInput:
                 batch.seq_lens + accept_length + 1,
                 batch.out_cache_loc[accept_index],
                 batch.req_to_token_pool.req_to_token.shape[1],
-                triton.next_power_of_2(bs),
+                next_power_of_2(bs),
             )
             batch.seq_lens.add_(accept_length + 1)
             accept_length_cpu = accept_length.tolist()

@@ -465,20 +511,21 @@ class EagleVerifyInput:
             draft_input = EagleDraftInput()
             if len(new_accept_index) > 0:
                 new_accept_index = torch.tensor(new_accept_index, device="cuda")
+                unfinished_index_device = torch.tensor(unfinished_index, device="cuda")
                 draft_input.hidden_states = batch.spec_info.hidden_states[
                     new_accept_index
                 ]
                 draft_input.verified_id = predict[new_accept_index]
-                draft_input.accept_length = accept_length[unfinished_index]
                 draft_input.accept_length_cpu = [
                     accept_length_cpu[i] for i in unfinished_index
                 ]
+                draft_input.accept_length = accept_length[unfinished_index_device]
                 if has_finished:
                     draft_input.seq_lens_for_draft_extend = batch.seq_lens[
-                        unfinished_index
+                        unfinished_index_device
                     ]
                     draft_input.req_pool_indices_for_draft_extend = (
-                        batch.req_pool_indices[unfinished_index]
+                        batch.req_pool_indices[unfinished_index_device]
                     )
             else:
                 draft_input.seq_lens_for_draft_extend = batch.seq_lens

@@ -564,13 +611,24 @@ def assign_draft_cache_locs(
     pool_len: tl.constexpr,
     topk: tl.constexpr,
     speculative_num_steps: tl.constexpr,
+    page_size: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 32
     pid = tl.program_id(axis=0)
+
     kv_start = tl.load(seq_lens + pid)
-    kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+    if page_size == 1 or topk == 1:
+        kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps
+        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
+    else:
+        prefix_len = tl.load(seq_lens + pid)
+        last_page_len = prefix_len % page_size
+        num_new_page = (
+            last_page_len + speculative_num_steps + page_size - 1
+        ) // page_size
+        kv_end = prefix_len // page_size * page_size + num_new_page * (page_size * topk)
+
     token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
-    out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps

     num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE)
     for i in range(num_loop):
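To see what the page_size > 1 branch computes, here is a worked example with illustrative numbers (not from the diff):

    # prefix_len = 10 tokens already cached, page_size = 4,
    # speculative_num_steps = 3 draft steps, topk = 2 draft branches.
    prefix_len, page_size, speculative_num_steps, topk = 10, 4, 3, 2
    last_page_len = prefix_len % page_size            # 2 tokens in the partial last page
    num_new_page = (
        last_page_len + speculative_num_steps + page_size - 1
    ) // page_size                                     # ceil((2 + 3) / 4) = 2 pages per branch
    kv_end = (
        prefix_len // page_size * page_size
        + num_new_page * (page_size * topk)
    )                                                  # 8 + 2 * (4 * 2) = 24

Each top-k branch therefore gets its own page-aligned region (including a copy of the partial last page), which matches the KV-cache layout comment added in eagle_worker.py further down.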
@@ -642,6 +700,29 @@ def generate_draft_decode_kv_indices(
     tl.store(kv_indptr + zid, base + zid * iters)


+@triton.jit
+def align_evict_mask_to_page_size(
+    seq_lens,
+    evict_mask,
+    page_size: tl.constexpr,
+    num_draft_tokens: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    t_range = tl.arange(0, BLOCK_SIZE)
+
+    bid = tl.program_id(axis=0)
+    seq_len = tl.load(seq_lens + bid)
+    io_mask = t_range < num_draft_tokens
+    mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask)
+
+    num_trues = tl.sum(mask_row)
+    num_false = num_draft_tokens - num_trues
+
+    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
+    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
+        tl.store(evict_mask + bid * num_draft_tokens + i, False)
+
+
 @torch.compile(dynamic=True)
 def select_top_k_tokens(
     i: int,
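The kernel above clears the evict flag for unaccepted draft tokens that share a page with accepted ones, so KV slots are only released on whole-page boundaries. A worked example with illustrative numbers (not from the diff):

    # page_size = 4, prefix seq_len = 9, draft_token_num = 6, and the verifier kept
    # the first two draft tokens (global positions 9 and 10).
    # evict_mask starts as [False, False, True, True, True, True].
    seq_len, page_size, num_draft_tokens, num_false = 9, 4, 6, 2
    start = (seq_len + num_false - 1) // page_size * page_size - seq_len   # 8 - 9 = -1
    keep = range(max(start, 0), min(start + page_size, num_draft_tokens))  # 0, 1, 2
    # Draft token 2 (position 11) shares page [8, 11] with the kept tokens, so its
    # evict flag is cleared; page [12, 15] holds only rejected tokens and is still freed.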
@@ -699,3 +780,34 @@ def fast_topk(values, topk, dim):
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
+
+
+def _generate_simulated_accept_index(
+    accept_index,
+    predict,
+    accept_length,
+    simulate_acc_len,
+    bs,
+    spec_steps,
+):
+    simulate_acc_len_float = float(simulate_acc_len)
+    simulated_values = torch.normal(
+        mean=simulate_acc_len_float,
+        std=1.0,
+        size=(1,),
+        device="cpu",
+    )
+    # clamp simulated values to be between 1 and self.spec_steps
+    simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps)
+    simulate_acc_len = int(simulated_values.round().item())
+
+    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
+    sim_accept_index = torch.full(
+        (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
+    )
+    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
+        simulate_acc_len, device=accept_index.device
+    )
+    accept_length.fill_(simulate_acc_len - 1)
+    predict.fill_(100)  # some legit token id
+    return sim_accept_index
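Note: SIMULATE_ACC_LEN is read from the environment; when set, the real accept indices are replaced by simulated ones of roughly that length, which is handy for benchmarking the non-verification parts of the pipeline. A presumed invocation (the launch command itself is illustrative, not from this commit):

    SIMULATE_ACC_LEN=3 python -m sglang.launch_server --speculative-algorithm EAGLE ...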
python/sglang/srt/speculative/eagle_worker.py  (+65, -11)

@@ -11,7 +11,7 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
 from sglang.srt.layers.dp_attention import disable_dp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
-from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.managers.schedule_batch import ScheduleBatch, get_last_loc
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,

@@ -67,6 +67,7 @@ class EAGLEWorker(TpModelWorker):
         self.gpu_id = gpu_id
         self.device = server_args.device
         self.target_worker = target_worker
+        self.page_size = server_args.page_size
         self.speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )

@@ -234,14 +235,11 @@ class EAGLEWorker(TpModelWorker):
         """
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
-                spec_info, to_free_cache_loc = self.draft(batch)
+                spec_info = self.draft(batch)
             logits_output, verify_output, model_worker_batch = self.verify(
                 batch, spec_info
             )

-            # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
-            self.token_to_kv_pool_allocator.free(to_free_cache_loc)
-
             # If it is None, it means all requests are finished
             if batch.spec_info.verified_id is not None:
                 with self.draft_tp_context(self.draft_model_runner.tp_group):

@@ -305,9 +303,59 @@ class EAGLEWorker(TpModelWorker):
         )

         # Allocate cache locations
-        out_cache_loc = batch.alloc_token_slots(
-            num_seqs * self.topk * self.speculative_num_steps
-        )
+        if self.page_size == 1:
+            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
+                num_seqs * self.topk * self.speculative_num_steps, backup_state=True
+            )
+        else:
+            if self.topk == 1:
+                prefix_lens = batch.seq_lens
+                seq_lens = prefix_lens + self.speculative_num_steps
+                extend_num_tokens = num_seqs * self.speculative_num_steps
+            else:
+                # In this case, the last partial page needs to be duplicated.
+                # KV cache layout in batch.req_to_token_pool.req_to_token:
+                #
+                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
+                #    prefix     top-k = 0    tok-k = 1    top-k = 2
+                #
+                #  "-" means prefix tokens
+                #  "x" means speculative draft tokens
+                #  "." means padded tokens
+
+                # TODO: fuse these ops
+                prefix_lens = batch.seq_lens
+                last_page_lens = prefix_lens % self.page_size
+                num_new_pages = (
+                    last_page_lens + self.speculative_num_steps + self.page_size - 1
+                ) // self.page_size
+                seq_lens = (
+                    prefix_lens // self.page_size * self.page_size
+                    + num_new_pages * (self.page_size * self.topk)
+                )
+                extend_num_tokens = torch.sum(seq_lens - prefix_lens).item()
+                raise NotImplementedError(
+                    "page_size > 1 and top_k > 1 are not supported."
+                )
+                # TODO: Support page_size > 1 and top_k > 1
+                # 1. Duplicate the KV cache in the last partial page for all top-k segments
+                # 2. Modify generate_draft_decode_kv_indices accordingly
+
+            last_loc = get_last_loc(
+                batch.req_to_token_pool.req_to_token,
+                batch.req_pool_indices,
+                prefix_lens,
+            )
+            out_cache_loc, token_to_kv_pool_state_backup = (
+                batch.alloc_paged_token_slots_extend(
+                    prefix_lens,
+                    seq_lens,
+                    last_loc,
+                    extend_num_tokens,
+                    backup_state=True,
+                )
+            )

         assign_draft_cache_locs[(num_seqs,)](
             batch.req_pool_indices,
             batch.req_to_token_pool.req_to_token,

@@ -316,6 +364,7 @@ class EAGLEWorker(TpModelWorker):
             batch.req_to_token_pool.req_to_token.shape[1],
             self.topk,
             self.speculative_num_steps,
+            self.page_size,
         )

         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()

@@ -343,6 +392,8 @@ class EAGLEWorker(TpModelWorker):
         # Run forward steps
         score_list, token_list, parents_list = self.draft_forward(forward_batch)

+        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+
         ret = EagleVerifyInput.create(
             spec_info.verified_id,
             score_list,

@@ -354,7 +405,7 @@ class EAGLEWorker(TpModelWorker):
             self.speculative_num_steps,
             self.server_args.speculative_num_draft_tokens,
         )
-        return ret, out_cache_loc
+        return ret

     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args

@@ -411,7 +462,7 @@ class EAGLEWorker(TpModelWorker):
         return score_list, token_list, parents_list

     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
-        spec_info.prepare_for_verify(batch)
+        spec_info.prepare_for_verify(batch, self.page_size)
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()

@@ -421,7 +472,10 @@ class EAGLEWorker(TpModelWorker):
         self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
         res: EagleVerifyOutput = spec_info.verify(
-            batch, logits_output, self.token_to_kv_pool_allocator
+            batch,
+            logits_output,
+            self.token_to_kv_pool_allocator,
+            self.page_size,
         )

         # Post process based on verified outputs.
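Putting the pieces together, the draft path now allocates with a snapshot and rolls the allocator back instead of returning the draft slots for an explicit free. A rough paraphrase of the new control flow (condensed, not verbatim; allocate_draft_kv is a stand-in for the page_size-dependent branch above):

    # Inside EAGLEWorker.draft(), roughly:
    out_cache_loc, backup = allocate_draft_kv(batch, backup_state=True)      # token- or page-based
    assign_draft_cache_locs[(num_seqs,)](..., self.page_size)                # map slots into req_to_token
    score_list, token_list, parents_list = self.draft_forward(forward_batch)
    self.token_to_kv_pool_allocator.restore_state(backup)                    # drop all draft KV at once
    return EagleVerifyInput.create(...)                                      # verify() allocates the real slots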
python/sglang/test/test_utils.py  (+11, -4)

@@ -76,11 +76,14 @@ def is_in_ci():
 if is_in_ci():
-    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157
-    DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157"
+    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
+        5000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
+    )
 else:
-    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 1157
-    DEFAULT_URL_FOR_TEST = "http://127.0.0.1:2157"
+    DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
+        7000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 100
+    )
+
+DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"


 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):

@@ -1009,6 +1012,9 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
 class CustomTestCase(unittest.TestCase):
+    pass
+
+    """
     def _callTestMethod(self, method):
         max_retry = int(
             os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")

@@ -1017,3 +1023,4 @@ class CustomTestCase(unittest.TestCase):
             lambda: super(CustomTestCase, self)._callTestMethod(method),
             max_retry=max_retry,
         )
+    """
scripts/ci_install_dependency.sh  (+1, -1)

@@ -18,7 +18,7 @@ pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --force-rei
 pip install sgl-kernel==0.0.5.post4 --force-reinstall
 pip install torch_memory_saver
-pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm
+pip install transformers==4.50.0 sentence_transformers accelerate==1.4.0 peft pandas datasets timm torchaudio

 # For compling xgrammar kernels
 pip install cuda-python nvidia-cuda-nvrtc-cu12
test/srt/run_suite.py  (+1, -1)

@@ -26,7 +26,7 @@ suites = {
         TestFile("test_abort.py", 51),
         TestFile("test_block_int8.py", 22),
         TestFile("test_chunked_prefill.py", 336),
-        TestFile("test_eagle_infer.py", 447),
+        TestFile("test_eagle_infer.py", 500),
         TestFile("test_ebnf_constrained.py"),
         TestFile("test_fp8_kernel.py", 2),
         TestFile("test_embedding_openai_server.py", 36),
test/srt/test_eagle_infer.py  (+40, -3)

@@ -298,10 +298,16 @@ class TestEAGLEServer(CustomTestCase):
         print(f"{metrics=}")
         self.assertGreater(metrics["accuracy"], 0.20)

-        server_info = requests.get(self.base_url + "/get_server_info")
-        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
+        server_info = requests.get(self.base_url + "/get_server_info").json()
+        avg_spec_accept_length = server_info["avg_spec_accept_length"]
         print(f"{avg_spec_accept_length=}")
-        self.assertGreater(avg_spec_accept_length, 3.5)
+
+        speculative_eagle_topk = server_info["speculative_eagle_topk"]
+        if speculative_eagle_topk == 1:
+            self.assertGreater(avg_spec_accept_length, 2.5)
+        else:
+            self.assertGreater(avg_spec_accept_length, 3.5)

         # Wait a little bit so that the memory check happens.
         time.sleep(4)

@@ -535,5 +541,36 @@ class TestEAGLEServerTriton(TestEAGLEServer):
     )


+class TestEAGLEServerPageSize(TestEAGLEServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--speculative-algorithm",
+                "EAGLE",
+                "--speculative-draft-model-path",
+                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "--speculative-num-steps",
+                5,
+                "--speculative-eagle-topk",
+                1,
+                "--speculative-num-draft-tokens",
+                6,
+                "--mem-fraction-static",
+                0.7,
+                "--chunked-prefill-size",
+                128,
+                "--max-running-requests",
+                8,
+                "--page-size",
+                4,
+            ],
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
test/srt/test_mla_flashinfer.py  (+1, -0)

@@ -157,6 +157,7 @@ class TestFlashinferMLAMTP(CustomTestCase):
         self.assertGreater(metrics["accuracy"], 0.60)

         server_info = requests.get(self.base_url + "/get_server_info")
+        print(f"{server_info=}")
         avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
         print(f"{avg_spec_accept_length=}")
         self.assertGreater(avg_spec_accept_length, 2.5)