Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
aba9eae4
Unverified
Commit
aba9eae4
authored
Oct 11, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 11, 2024
Browse files
Fix the correctness test in bench_latency.py when tp > 1 and test_generation_models.py (#1631)
parent
bbd72bfc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
8 deletions
+17
-8
python/sglang/bench_latency.py
python/sglang/bench_latency.py
+2
-2
test/srt/models/test_generation_models.py
test/srt/models/test_generation_models.py
+15
-6
No files found.
python/sglang/bench_latency.py
View file @
aba9eae4
...
@@ -220,6 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
...
@@ -220,6 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
return
reqs
return
reqs
@
torch
.
inference_mode
()
def
extend
(
reqs
,
model_runner
):
def
extend
(
reqs
,
model_runner
):
batch
=
ScheduleBatch
.
init_new
(
batch
=
ScheduleBatch
.
init_new
(
reqs
=
reqs
,
reqs
=
reqs
,
...
@@ -235,6 +236,7 @@ def extend(reqs, model_runner):
...
@@ -235,6 +236,7 @@ def extend(reqs, model_runner):
return
next_token_ids
,
logits_output
.
next_token_logits
,
batch
return
next_token_ids
,
logits_output
.
next_token_logits
,
batch
@
torch
.
inference_mode
()
def
decode
(
input_token_ids
,
batch
,
model_runner
):
def
decode
(
input_token_ids
,
batch
,
model_runner
):
batch
.
prepare_for_decode
(
input_token_ids
)
batch
.
prepare_for_decode
(
input_token_ids
)
model_worker_batch
=
batch
.
get_model_worker_batch
()
model_worker_batch
=
batch
.
get_model_worker_batch
()
...
@@ -244,7 +246,6 @@ def decode(input_token_ids, batch, model_runner):
...
@@ -244,7 +246,6 @@ def decode(input_token_ids, batch, model_runner):
return
next_token_ids
,
logits_output
.
next_token_logits
return
next_token_ids
,
logits_output
.
next_token_logits
@
torch
.
inference_mode
()
def
correctness_test
(
def
correctness_test
(
server_args
,
server_args
,
port_args
,
port_args
,
...
@@ -287,7 +288,6 @@ def correctness_test(
...
@@ -287,7 +288,6 @@ def correctness_test(
rank_print
(
tokenizer
.
decode
(
output_ids
[
i
]),
"
\n
"
)
rank_print
(
tokenizer
.
decode
(
output_ids
[
i
]),
"
\n
"
)
@
torch
.
inference_mode
()
def
latency_test_run_once
(
def
latency_test_run_once
(
run_name
,
model_runner
,
rank_print
,
reqs
,
batch_size
,
input_len
,
output_len
run_name
,
model_runner
,
rank_print
,
reqs
,
batch_size
,
input_len
,
output_len
):
):
...
...
test/srt/models/test_generation_models.py
View file @
aba9eae4
...
@@ -42,13 +42,13 @@ class ModelCase:
...
@@ -42,13 +42,13 @@ class ModelCase:
rouge_l_tolerance
:
float
=
1
rouge_l_tolerance
:
float
=
1
# Popular models that run on CI
# Popular models that run on
the
CI
CI_MODELS
=
[
CI_MODELS
=
[
ModelCase
(
"meta-llama/Llama-3.1-8B-Instruct"
),
ModelCase
(
"meta-llama/Llama-3.1-8B-Instruct"
),
ModelCase
(
"google/gemma-2-2b"
),
ModelCase
(
"google/gemma-2-2b"
),
]
]
# All other models
# All other models
that do not run on the CI
ALL_OTHER_MODELS
=
[
ALL_OTHER_MODELS
=
[
ModelCase
(
"Qwen/Qwen2-1.5B"
),
ModelCase
(
"Qwen/Qwen2-1.5B"
),
ModelCase
(
"Qwen/Qwen2.5-14B-Instruct"
),
ModelCase
(
"Qwen/Qwen2.5-14B-Instruct"
),
...
@@ -59,6 +59,10 @@ TORCH_DTYPES = [torch.float16]
...
@@ -59,6 +59,10 @@ TORCH_DTYPES = [torch.float16]
class
TestGenerationModels
(
unittest
.
TestCase
):
class
TestGenerationModels
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
mp
.
set_start_method
(
"spawn"
)
def
assert_close_logits_and_output_strs
(
def
assert_close_logits_and_output_strs
(
self
,
self
,
prompts
:
List
[
str
],
prompts
:
List
[
str
],
...
@@ -140,16 +144,21 @@ class TestGenerationModels(unittest.TestCase):
...
@@ -140,16 +144,21 @@ class TestGenerationModels(unittest.TestCase):
return
return
for
model_case
in
ALL_OTHER_MODELS
:
for
model_case
in
ALL_OTHER_MODELS
:
# Only run a specified model
if
(
if
(
"ONLY_RUN"
in
os
.
environ
"ONLY_RUN"
in
os
.
environ
and
os
.
environ
[
"ONLY_RUN"
]
!=
model_case
.
model_path
and
os
.
environ
[
"ONLY_RUN"
]
!=
model_case
.
model_path
):
):
continue
continue
self
.
assert_close_logits_and_output_strs
(
DEFAULT_PROMPTS
,
model_case
,
torch
.
float16
# Skip long prompts for models that does not have a long context
)
prompts
=
DEFAULT_PROMPTS
if
model_case
.
model_path
in
[
"HuggingFaceTB/SmolLM-135M-Instruct"
]:
prompts
=
[
p
for
p
in
DEFAULT_PROMPTS
if
len
(
p
)
<
1000
]
# Assert the logits and output strs are close
self
.
assert_close_logits_and_output_strs
(
prompts
,
model_case
,
torch
.
float16
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
mp
.
set_start_method
(
"spawn"
)
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment