Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2cf0df33
Unverified
Commit
2cf0df33
authored
Jul 24, 2024
by
Nick Hill
Committed by
GitHub
Jul 24, 2024
Browse files
[Bugfix] Fix speculative decode seeded test (#6743)
parent
54514634
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
6 deletions
+19
-6
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+2
-1
tests/spec_decode/e2e/test_seed.py
tests/spec_decode/e2e/test_seed.py
+17
-5
No files found.
tests/spec_decode/e2e/conftest.py
View file @
2cf0df33
...
@@ -191,7 +191,8 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
...
@@ -191,7 +191,8 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
and
llm
.
llm_engine
.
log_stats
):
and
llm
.
llm_engine
.
log_stats
):
for
sate_logger
in
llm
.
llm_engine
.
stat_loggers
.
values
():
for
sate_logger
in
llm
.
llm_engine
.
stat_loggers
.
values
():
sate_logger
.
local_interval
=
0
sate_logger
.
local_interval
=
0
set_random_seed
(
seed
)
if
seed
is
not
None
:
set_random_seed
(
seed
)
yield
llm
yield
llm
del
llm
del
llm
...
...
tests/spec_decode/e2e/test_seed.py
View file @
2cf0df33
...
@@ -21,7 +21,8 @@ from .conftest import run_equality_correctness_test
...
@@ -21,7 +21,8 @@ from .conftest import run_equality_correctness_test
"num_speculative_tokens"
:
3
,
"num_speculative_tokens"
:
3
,
}])
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
"seed"
:
1
}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"seed"
:
5
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
8
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
8
,
32
])
@
pytest
.
mark
.
parametrize
(
"temperature"
,
[
0.1
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"temperature"
,
[
0.1
,
1.0
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -30,15 +31,26 @@ from .conftest import run_equality_correctness_test
...
@@ -30,15 +31,26 @@ from .conftest import run_equality_correctness_test
# Use smaller output len for fast test.
# Use smaller output len for fast test.
10
,
10
,
])
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
None
])
def
test_seeded_consistency
(
baseline_llm_generator
,
batch_size
:
int
,
def
test_seeded_consistency
(
baseline_llm_generator
,
test_llm_generator
,
temperature
:
float
,
output_len
:
int
):
batch_size
:
int
,
temperature
:
float
,
output_len
:
int
):
"""Verify outputs are consistent across multiple runs with same seed
"""Verify outputs are consistent across multiple runs with same seed
"""
"""
run_equality_correctness_test
(
baseline_llm_generator
,
run_equality_correctness_test
(
baseline_llm_generator
,
baseline
_llm_generator
,
test
_llm_generator
,
batch_size
,
batch_size
,
max_output_len
=
output_len
,
max_output_len
=
output_len
,
temperature
=
temperature
,
temperature
=
temperature
,
seeded
=
True
,
seeded
=
True
,
force_output_len
=
True
)
force_output_len
=
True
)
# Ensure this same test does fail if we _don't_ include per-request seeds
with
pytest
.
raises
(
AssertionError
):
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
temperature
,
seeded
=
False
,
force_output_len
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment