Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d3d4d767
Unverified
Commit
d3d4d767
authored
Mar 05, 2025
by
Ying Sheng
Committed by
GitHub
Mar 05, 2025
Browse files
[Eagle] Refactor eagle speculative decoding (#3986)
Co-authored-by:
Ke Bao
<
ISPObaoke@163.com
>
parent
5be8f1ed
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
143 additions
and
5 deletions
+143
-5
python/sglang/srt/speculative/spec_info.py
python/sglang/srt/speculative/spec_info.py
+0
-4
test/srt/test_eagle_infer.py
test/srt/test_eagle_infer.py
+143
-1
No files found.
python/sglang/srt/speculative/spec_info.py
View file @
d3d4d767
...
@@ -20,7 +20,3 @@ class SpeculativeAlgorithm(IntEnum):
...
@@ -20,7 +20,3 @@ class SpeculativeAlgorithm(IntEnum):
if
name
is
not
None
:
if
name
is
not
None
:
name
=
name
.
upper
()
name
=
name
.
upper
()
return
name_map
[
name
]
return
name_map
[
name
]
class
SpecInfo
:
pass
test/srt/test_eagle_infer.py
View file @
d3d4d767
import
multiprocessing
as
mp
import
random
import
random
import
threading
import
threading
import
time
import
time
...
@@ -18,6 +19,8 @@ from sglang.test.test_utils import (
...
@@ -18,6 +19,8 @@ from sglang.test.test_utils import (
popen_launch_server
,
popen_launch_server
,
)
)
acc_rate_tolerance
=
0.15
class
TestEAGLEEngine
(
unittest
.
TestCase
):
class
TestEAGLEEngine
(
unittest
.
TestCase
):
BASE_CONFIG
=
{
BASE_CONFIG
=
{
...
@@ -43,13 +46,19 @@ class TestEAGLEEngine(unittest.TestCase):
...
@@ -43,13 +46,19 @@ class TestEAGLEEngine(unittest.TestCase):
configs
=
[
configs
=
[
self
.
BASE_CONFIG
,
self
.
BASE_CONFIG
,
{
**
self
.
BASE_CONFIG
,
"disable_cuda_graph"
:
True
},
{
**
self
.
BASE_CONFIG
,
"disable_cuda_graph"
:
True
},
{
**
self
.
BASE_CONFIG
,
"chunked_prefill_size"
:
2
},
]
]
for
config
in
configs
:
for
config
in
configs
:
with
self
.
subTest
(
with
self
.
subTest
(
cuda_graph
=
(
cuda_graph
=
(
"enabled"
if
len
(
config
)
==
len
(
self
.
BASE_CONFIG
)
else
"disabled"
"enabled"
if
len
(
config
)
==
len
(
self
.
BASE_CONFIG
)
else
"disabled"
)
),
chunked_prefill_size
=
(
config
[
"chunked_prefill_size"
]
if
"chunked_prefill_size"
in
config
else
"default"
),
):
):
engine
=
sgl
.
Engine
(
**
config
)
engine
=
sgl
.
Engine
(
**
config
)
try
:
try
:
...
@@ -125,6 +134,8 @@ class TestEAGLEServer(unittest.TestCase):
...
@@ -125,6 +134,8 @@ class TestEAGLEServer(unittest.TestCase):
"64"
,
"64"
,
"--mem-fraction-static"
,
"--mem-fraction-static"
,
"0.7"
,
"0.7"
,
"--chunked-prefill-size"
,
"128"
,
"--cuda-graph-max-bs"
,
"--cuda-graph-max-bs"
,
"32"
,
"32"
,
],
],
...
@@ -196,6 +207,137 @@ class TestEAGLEServer(unittest.TestCase):
...
@@ -196,6 +207,137 @@ class TestEAGLEServer(unittest.TestCase):
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.20
)
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.20
)
def
measure_acc_rate
(
engine
):
tic
=
time
.
time
()
prompt
=
[
"Human: Give me a fully functional FastAPI server. Show the python code.<|separator|>
\n\n
Assistant:"
]
sampling_params
=
{
"temperature"
:
0
,
"max_new_tokens"
:
512
}
output
=
engine
.
generate
(
prompt
,
sampling_params
)
output
=
output
[
0
]
latency
=
time
.
time
()
-
tic
if
"spec_verify_ct"
in
output
[
"meta_info"
]:
base_acc_length
=
(
output
[
"meta_info"
][
"completion_tokens"
]
/
output
[
"meta_info"
][
"spec_verify_ct"
]
)
else
:
base_acc_length
=
0.0
base_speed
=
output
[
"meta_info"
][
"completion_tokens"
]
/
latency
return
base_acc_length
,
base_speed
class
TestEagleAcceptanceRate
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
ref_engine
=
sgl
.
Engine
(
model_path
=
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
speculative_draft_model_path
=
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
speculative_algorithm
=
"EAGLE"
,
speculative_num_steps
=
5
,
speculative_eagle_topk
=
8
,
speculative_num_draft_tokens
=
64
,
mem_fraction_static
=
0.7
,
disable_radix_cache
=
True
,
)
cls
.
base_acc_length
,
cls
.
base_speed
=
measure_acc_rate
(
ref_engine
)
ref_engine
.
shutdown
()
assert
cls
.
base_acc_length
>
4.45
def
test_acc_rate
(
self
):
base_acc_length
,
base_speed
=
self
.
base_acc_length
,
self
.
base_speed
chunk_engine
=
sgl
.
Engine
(
model_path
=
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
speculative_draft_model_path
=
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
speculative_algorithm
=
"EAGLE"
,
speculative_num_steps
=
5
,
speculative_eagle_topk
=
8
,
speculative_num_draft_tokens
=
64
,
mem_fraction_static
=
0.7
,
chunked_prefill_size
=
2
,
disable_radix_cache
=
True
,
)
chunked_acc_length
,
chunked_base_speed
=
measure_acc_rate
(
chunk_engine
)
chunk_engine
.
shutdown
()
print
(
base_acc_length
,
base_speed
)
print
(
chunked_acc_length
,
chunked_base_speed
)
assert
abs
(
base_acc_length
-
chunked_acc_length
)
<
acc_rate_tolerance
def
test_acc_rate_prefix_caching
(
self
):
base_acc_length
,
base_speed
=
self
.
base_acc_length
,
self
.
base_speed
prefix_caching_engine
=
sgl
.
Engine
(
model_path
=
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
speculative_draft_model_path
=
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
speculative_algorithm
=
"EAGLE"
,
speculative_num_steps
=
5
,
speculative_eagle_topk
=
8
,
speculative_num_draft_tokens
=
64
,
mem_fraction_static
=
0.7
,
chunked_prefill_size
=
4
,
schedule_policy
=
"lpm"
,
)
for
_
in
range
(
10
):
acc_length
,
_
=
measure_acc_rate
(
prefix_caching_engine
)
print
(
f
"
{
acc_length
=
}
"
)
assert
abs
(
base_acc_length
-
acc_length
)
<
acc_rate_tolerance
# The second one should hit the prefix cache.
prefix_caching_engine
.
shutdown
()
class
TestEAGLERetract
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
process
=
popen_launch_server
(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
[
"--speculative-algorithm"
,
"EAGLE"
,
"--speculative-draft-model-path"
,
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
,
"--speculative-num-steps"
,
"5"
,
"--speculative-eagle-topk"
,
"8"
,
"--speculative-num-draft-tokens"
,
"64"
,
"--mem-fraction-static"
,
"0.7"
,
"--chunked-prefill-size"
,
"128"
,
"--max-running-requests"
,
"64"
,
],
)
@
classmethod
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
num_shots
=
5
,
data_path
=
None
,
num_questions
=
200
,
max_new_tokens
=
512
,
parallel
=
128
,
host
=
"http://127.0.0.1"
,
port
=
int
(
self
.
base_url
.
split
(
":"
)[
-
1
]),
)
metrics
=
run_eval
(
args
)
print
(
f
"
{
metrics
=
}
"
)
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.20
)
# Wait a little bit so that the memory check happens.
time
.
sleep
(
5
)
class
TestEAGLEServerTriton
(
TestEAGLEServer
):
class
TestEAGLEServerTriton
(
TestEAGLEServer
):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment