Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a921e863
Unverified
Commit
a921e863
authored
Jul 19, 2024
by
Woo-Yeon Lee
Committed by
GitHub
Jul 19, 2024
Browse files
[BUGFIX] Raise an error for no draft token case when draft_tp>1 (#6369)
parent
6366efc6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
85 additions
and
5 deletions
+85
-5
tests/spec_decode/e2e/test_integration_dist_tp4.py
tests/spec_decode/e2e/test_integration_dist_tp4.py
+62
-0
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+3
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+19
-4
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+1
-1
No files found.
tests/spec_decode/e2e/test_integration_dist_tp4.py
View file @
a921e863
...
@@ -58,3 +58,65 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
...
@@ -58,3 +58,65 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
batch_size
,
batch_size
,
max_output_len
=
32
,
max_output_len
=
32
,
force_output_len
=
True
)
force_output_len
=
True
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
4
,
reason
=
"Need at least 4 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-160m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"tensor_parallel_size"
:
4
,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async"
:
True
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
5
,
# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len"
:
32
,
},
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# This must be a good bit larger than speculative_max_model_len so that
# we can test the case where all seqs are skipped, but still small to
# ensure fast test.
64
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_skip_speculation
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify job failure with RuntimeError when all sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
TODO: fix it to pass without raising Error. (#5814)
"""
with
pytest
.
raises
(
RuntimeError
):
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
vllm/spec_decode/interfaces.py
View file @
a921e863
...
@@ -22,6 +22,9 @@ class SpeculativeProposals:
...
@@ -22,6 +22,9 @@ class SpeculativeProposals:
# The valid length of each proposal; can be zero.
# The valid length of each proposal; can be zero.
proposal_lens
:
torch
.
Tensor
proposal_lens
:
torch
.
Tensor
# A flag to mark that there's no available proposals
no_proposals
:
bool
=
False
def
__repr__
(
self
):
def
__repr__
(
self
):
return
(
f
"SpeculativeProposals("
return
(
f
"SpeculativeProposals("
f
"proposal_token_ids=
{
self
.
proposal_token_ids
}
, "
f
"proposal_token_ids=
{
self
.
proposal_token_ids
}
, "
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
a921e863
...
@@ -109,6 +109,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -109,6 +109,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
typical_acceptance_sampler_posterior_alpha
:
float
,
typical_acceptance_sampler_posterior_alpha
:
float
,
)
->
"SpecDecodeWorker"
:
)
->
"SpecDecodeWorker"
:
allow_zero_draft_token_step
=
True
ngram_prompt_lookup_max
=
(
ngram_prompt_lookup_max
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_max"
))
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_max"
))
ngram_prompt_lookup_min
=
(
ngram_prompt_lookup_min
=
(
...
@@ -133,6 +134,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -133,6 +134,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
if
draft_tp
==
1
:
if
draft_tp
==
1
:
draft_worker_kwargs
[
draft_worker_kwargs
[
"model_runner_cls"
]
=
TP1DraftModelRunner
"model_runner_cls"
]
=
TP1DraftModelRunner
else
:
allow_zero_draft_token_step
=
False
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
proposer_worker
=
SmallerTpProposerWorker
.
maybe_wrap_worker
(
proposer_worker
=
SmallerTpProposerWorker
.
maybe_wrap_worker
(
...
@@ -155,10 +158,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -155,10 +158,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger
.
info
(
"Configuring SpecDecodeWorker with sampler=%s"
,
logger
.
info
(
"Configuring SpecDecodeWorker with sampler=%s"
,
type
(
spec_decode_sampler
))
type
(
spec_decode_sampler
))
return
SpecDecodeWorker
(
proposer_worker
,
return
SpecDecodeWorker
(
scorer_worker
,
proposer_worker
,
disable_by_batch_size
=
disable_by_batch_size
,
scorer_worker
,
spec_decode_sampler
=
spec_decode_sampler
)
disable_by_batch_size
=
disable_by_batch_size
,
spec_decode_sampler
=
spec_decode_sampler
,
allow_zero_draft_token_step
=
allow_zero_draft_token_step
)
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -167,6 +172,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -167,6 +172,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
spec_decode_sampler
:
SpecDecodeBaseSampler
,
spec_decode_sampler
:
SpecDecodeBaseSampler
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
disable_by_batch_size
:
Optional
[
int
]
=
None
,
allow_zero_draft_token_step
:
Optional
[
bool
]
=
True
,
):
):
"""
"""
Create a SpecDecodeWorker.
Create a SpecDecodeWorker.
...
@@ -187,11 +193,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -187,11 +193,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
disable speculative decoding for new incoming requests.
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
for testing purposes.
allow_zero_draft_token_step: whether to allow a step where the draft
model generates no draft token; should disallow when the tp of
draft model is larger than 1 (TODO: #5814)
"""
"""
self
.
proposer_worker
=
proposer_worker
self
.
proposer_worker
=
proposer_worker
self
.
scorer_worker
=
scorer_worker
self
.
scorer_worker
=
scorer_worker
self
.
disable_by_batch_size
=
disable_by_batch_size
or
float
(
"inf"
)
self
.
disable_by_batch_size
=
disable_by_batch_size
or
float
(
"inf"
)
self
.
spec_decode_sampler
=
spec_decode_sampler
self
.
spec_decode_sampler
=
spec_decode_sampler
self
.
_allow_zero_draft_token_step
=
allow_zero_draft_token_step
self
.
_metrics
=
AsyncMetricsCollector
(
self
.
_metrics
=
AsyncMetricsCollector
(
self
.
spec_decode_sampler
self
.
spec_decode_sampler
)
if
metrics_collector
is
None
else
metrics_collector
)
if
metrics_collector
is
None
else
metrics_collector
...
@@ -461,6 +471,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -461,6 +471,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
execute_model_req
,
self
.
_seq_with_bonus_token_in_last_step
)
execute_model_req
,
self
.
_seq_with_bonus_token_in_last_step
)
if
not
self
.
_allow_zero_draft_token_step
and
proposals
.
no_proposals
:
#TODO: Fix it #5814
raise
RuntimeError
(
"Cannot handle cases where distributed draft "
"workers generate no tokens"
)
proposal_scores
=
self
.
scorer
.
score_proposals
(
proposal_scores
=
self
.
scorer
.
score_proposals
(
execute_model_req
,
execute_model_req
,
proposals
,
proposals
,
...
...
vllm/spec_decode/top1_proposer.py
View file @
a921e863
...
@@ -108,7 +108,7 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -108,7 +108,7 @@ class Top1Proposer(SpeculativeProposer):
proposal_token_ids
=
proposal_tokens
,
proposal_token_ids
=
proposal_tokens
,
proposal_probs
=
proposal_probs
,
proposal_probs
=
proposal_probs
,
proposal_lens
=
proposal_lens
,
proposal_lens
=
proposal_lens
,
)
no_proposals
=
maybe_sampler_output
is
None
)
return
proposals
return
proposals
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment