Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d5a16977
Unverified
Commit
d5a16977
authored
May 25, 2024
by
Lily Liu
Committed by
GitHub
May 25, 2024
Browse files
[Dynamic Spec Decoding] Minor fix for disabling speculative decoding (#5000)
parent
325c1199
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
11 deletions
+63
-11
tests/spec_decode/e2e/test_ngram_correctness.py
tests/spec_decode/e2e/test_ngram_correctness.py
+41
-0
tests/spec_decode/test_dynamic_spec_decode.py
tests/spec_decode/test_dynamic_spec_decode.py
+10
-6
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+12
-5
No files found.
tests/spec_decode/e2e/test_ngram_correctness.py
View file @
d5a16977
...
@@ -170,3 +170,44 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
...
@@ -170,3 +170,44 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
batch_size
,
batch_size
,
max_output_len
=
output_len
,
max_output_len
=
output_len
,
force_output_len
=
True
)
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
"speculative_disable_by_batch_size"
:
4
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_disable_queue
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
tests/spec_decode/test_dynamic_spec_decode.py
View file @
d5a16977
from
unittest.mock
import
MagicMock
from
unittest.mock
import
MagicMock
,
patch
import
pytest
import
pytest
import
torch
import
torch
...
@@ -13,9 +13,9 @@ from vllm.spec_decode.top1_proposer import Top1Proposer
...
@@ -13,9 +13,9 @@ from vllm.spec_decode.top1_proposer import Top1Proposer
from
.utils
import
create_batch
,
mock_worker
from
.utils
import
create_batch
,
mock_worker
@
pytest
.
mark
.
parametrize
(
'queue_size'
,
[
2
,
4
])
@
pytest
.
mark
.
parametrize
(
'queue_size'
,
[
4
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
3
,
6
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
,
2
,
5
,
7
,
10
])
@
pytest
.
mark
.
parametrize
(
'k'
,
[
1
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_disable_spec_tokens
(
queue_size
:
int
,
batch_size
:
int
,
k
:
int
):
def
test_disable_spec_tokens
(
queue_size
:
int
,
batch_size
:
int
,
k
:
int
):
"""Verify that speculative tokens are disabled when the batch size
"""Verify that speculative tokens are disabled when the batch size
...
@@ -42,8 +42,12 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
...
@@ -42,8 +42,12 @@ def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
num_lookahead_slots
=
k
,
num_lookahead_slots
=
k
,
running_queue_size
=
queue_size
)
running_queue_size
=
queue_size
)
with
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
if
queue_size
>
disable_by_batch_size
:
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
with
patch
.
object
(
worker
,
'_run_no_spec'
,
side_effect
=
ValueError
(
exception_secret
)),
\
pytest
.
raises
(
ValueError
,
match
=
exception_secret
):
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
# When the batch size is larger than the threshold,
# When the batch size is larger than the threshold,
# we expect no speculative tokens (0).
# we expect no speculative tokens (0).
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
d5a16977
...
@@ -273,10 +273,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -273,10 +273,17 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
_maybe_disable_speculative_tokens
(
self
.
_maybe_disable_speculative_tokens
(
disable_all_speculation
,
execute_model_req
.
seq_group_metadata_list
)
disable_all_speculation
,
execute_model_req
.
seq_group_metadata_list
)
# If no spec tokens, call the proposer and scorer workers normally.
# Speculative decoding is disabled in the following cases:
# Used for prefill.
# 1. Prefill phase: Speculative decoding is not
# used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold.
# 3. No request: There are no requests in the batch.
# In any of these cases, the proposer and scorer workers
# are called normally.
if
num_lookahead_slots
==
0
or
len
(
if
num_lookahead_slots
==
0
or
len
(
execute_model_req
.
seq_group_metadata_list
)
==
0
:
execute_model_req
.
seq_group_metadata_list
)
==
0
or
disable_all_speculation
:
return
self
.
_run_no_spec
(
execute_model_req
,
return
self
.
_run_no_spec
(
execute_model_req
,
skip_proposer
=
disable_all_speculation
)
skip_proposer
=
disable_all_speculation
)
...
@@ -316,8 +323,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -316,8 +323,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
def
_run_no_spec
(
self
,
execute_model_req
:
ExecuteModelRequest
,
def
_run_no_spec
(
self
,
execute_model_req
:
ExecuteModelRequest
,
skip_proposer
:
bool
)
->
List
[
SamplerOutput
]:
skip_proposer
:
bool
)
->
List
[
SamplerOutput
]:
"""Run a
prefill
step
,
without any speculation. The input is
sent to
"""Run a
single generation
step without any speculation. The input is
the proposer and scorer model so that the KV cache is consistent
sent to
the proposer and scorer model so that the KV cache is consistent
between the two. When skip_proposer is True, the proposer model is
between the two. When skip_proposer is True, the proposer model is
not called, meaning that the kv-cache in proposer for requests is not
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
updated, so they cannot enable spec decode in the rest decoding.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment