sglang · Commit 60abdb3e
Unverified · authored Feb 09, 2025 by Yineng Zhang, committed via GitHub on Feb 09, 2025
minor: cleanup test_eagle_infer (#3415)
parent 7b4e61ff
Showing 1 changed file with 63 additions and 64 deletions.
test/srt/test_eagle_infer.py (+63, −64) · view file @ 60abdb3e
@@ -20,79 +20,78 @@ from sglang.test.test_utils import (
 class TestEAGLEEngine(unittest.TestCase):
+    BASE_CONFIG = {
+        "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+        "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+        "speculative_algorithm": "EAGLE",
+        "speculative_num_steps": 5,
+        "speculative_eagle_topk": 8,
+        "speculative_num_draft_tokens": 64,
+        "mem_fraction_static": 0.7,
+    }
 
-    def test_eagle_accuracy(self):
-        prompt1 = "Today is a sunny day and I like"
-        sampling_params1 = {"temperature": 0, "max_new_tokens": 8}
+    def setUp(self):
+        self.prompt = "Today is a sunny day and I like"
+        self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
 
-        # Get the reference output
         ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-        ref_output = ref_engine.generate(prompt1, sampling_params1)["text"]
+        self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)[
+            "text"
+        ]
         ref_engine.shutdown()
 
-        # Test cases with different configurations
+    def test_eagle_accuracy(self):
         configs = [
-            # Original config
-            {
-                "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
-                "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
-                "speculative_algorithm": "EAGLE",
-                "speculative_num_steps": 5,
-                "speculative_eagle_topk": 8,
-                "speculative_num_draft_tokens": 64,
-                "mem_fraction_static": 0.7,
-            },
-            # Config with CUDA graph disabled
-            {
-                "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
-                "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
-                "speculative_algorithm": "EAGLE",
-                "speculative_num_steps": 5,
-                "speculative_eagle_topk": 8,
-                "speculative_num_draft_tokens": 64,
-                "mem_fraction_static": 0.7,
-                "disable_cuda_graph": True,
-            },
+            self.BASE_CONFIG,
+            {**self.BASE_CONFIG, "disable_cuda_graph": True},
         ]
 
         for config in configs:
-            # Launch EAGLE engine
-            engine = sgl.Engine(**config)
-
-            # Case 1: Test the output of EAGLE engine is the same as normal engine
-            out1 = engine.generate(prompt1, sampling_params1)["text"]
-            print(f"{out1=}, {ref_output=}")
-            self.assertEqual(out1, ref_output)
-
-            # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
-            prompt2 = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
-            sampling_params2 = {
-                "temperature": 0,
-                "max_new_tokens": 1024,
-                "skip_special_tokens": False,
-            }
+            with self.subTest(
+                cuda_graph=(
+                    "enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
+                )
+            ):
+                engine = sgl.Engine(**config)
+                try:
+                    self._test_basic_generation(engine)
+                    self._test_eos_token(engine)
+                    self._test_batch_generation(engine)
+                finally:
+                    engine.shutdown()
+
+    def _test_basic_generation(self, engine):
+        output = engine.generate(self.prompt, self.sampling_params)["text"]
+        print(f"{output=}, {self.ref_output=}")
+        self.assertEqual(output, self.ref_output)
+
+    def _test_eos_token(self, engine):
+        prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
+        params = {
+            "temperature": 0,
+            "max_new_tokens": 1024,
+            "skip_special_tokens": False,
+        }
 
-            tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-            out2 = engine.generate(prompt2, sampling_params2)["text"]
-            print(f"{out2=}")
-            tokens = tokenizer.encode(out2, truncation=False)
-            assert tokenizer.eos_token_id not in tokens
+        tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+        output = engine.generate(prompt, params)["text"]
+        print(f"{output=}")
 
-            # Case 3: Batched prompts
-            prompts = [
-                "Hello, my name is",
-                "The president of the United States is",
-                "The capital of France is",
-                "The future of AI is",
-            ]
-            sampling_params3 = {"temperature": 0, "max_new_tokens": 30}
-            outputs = engine.generate(prompts, sampling_params3)
-            for prompt, output in zip(prompts, outputs):
-                print("===============================")
-                print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
+        tokens = tokenizer.encode(output, truncation=False)
+        self.assertNotIn(tokenizer.eos_token_id, tokens)
 
-            # Shutdown the engine
-            engine.shutdown()
+    def _test_batch_generation(self, engine):
+        prompts = [
+            "Hello, my name is",
+            "The president of the United States is",
+            "The capital of France is",
+            "The future of AI is",
+        ]
+        params = {"temperature": 0, "max_new_tokens": 30}
+
+        outputs = engine.generate(prompts, params)
+        for prompt, output in zip(prompts, outputs):
+            print(f"Prompt: {prompt}")
+            print(f"Generated: {output['text']}")
+            print("-" * 40)
 
 prompts = [
...
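Two idioms carry this cleanup: deriving the second engine config from BASE_CONFIG with dict unpacking instead of duplicating the literal, and running each config under subTest so a failure in one variant is reported separately and does not stop the other. Below is a minimal, self-contained sketch of that pattern; the class name and config keys here are illustrative stand-ins, not a real sglang engine launch.

import unittest


class ConfigVariantsTest(unittest.TestCase):
    # Single source of truth for the engine arguments; variants are
    # derived from it rather than copy-pasted.
    BASE_CONFIG = {
        "model_path": "illustrative-target-model",  # hypothetical value
        "speculative_num_steps": 5,
    }

    def test_all_variants(self):
        configs = [
            self.BASE_CONFIG,
            # {**base, key: value} shallow-copies the base and adds or
            # overrides one key, so the variant cannot drift out of sync.
            {**self.BASE_CONFIG, "disable_cuda_graph": True},
        ]
        for config in configs:
            # Each variant is reported as its own subtest: an assertion
            # failure here is recorded, then the loop moves on.
            with self.subTest(
                cuda_graph="disabled" if config.get("disable_cuda_graph") else "enabled"
            ):
                self.assertEqual(config["speculative_num_steps"], 5)


if __name__ == "__main__":
    unittest.main()

Against the real file, the suite would be run in the usual unittest way, e.g. python -m unittest test.srt.test_eagle_infer.TestEAGLEEngine, assuming the EAGLE target and draft models are available locally.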