Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b38e42fb
Unverified
Commit
b38e42fb
authored
May 02, 2024
by
leiwen83
Committed by
GitHub
May 01, 2024
Browse files
[Speculative decoding] Add ngram prompt lookup decoding (#4237)
Co-authored-by:
Lei Wen
<
wenlei03@qiyi.com
>
parent
8b798eec
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
1004 additions
and
319 deletions
+1004
-319
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+58
-0
tests/spec_decode/e2e/test_multistep_correctness.py
tests/spec_decode/e2e/test_multistep_correctness.py
+2
-58
tests/spec_decode/e2e/test_ngram_correctness.py
tests/spec_decode/e2e/test_ngram_correctness.py
+172
-0
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+25
-25
tests/spec_decode/test_ngram_worker.py
tests/spec_decode/test_ngram_worker.py
+206
-0
vllm/config.py
vllm/config.py
+59
-28
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+18
-0
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+4
-4
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+2
-2
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+24
-185
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+190
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+32
-13
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+200
-0
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+12
-4
No files found.
tests/spec_decode/e2e/conftest.py
View file @
b38e42fb
import
asyncio
from
itertools
import
cycle
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
pytest
...
...
@@ -185,3 +186,60 @@ def get_output_from_llm_generator(
del
llm
return
tokens
,
token_ids
def
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
,
force_output_len
:
bool
,
print_tokens
:
bool
=
False
):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos
=
force_output_len
sampling_params
=
SamplingParams
(
max_tokens
=
max_output_len
,
ignore_eos
=
ignore_eos
,
temperature
=
temperature
,
)
spec_batch_tokens
,
spec_batch_token_ids
=
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
(
baseline_batch_tokens
,
baseline_batch_token_ids
)
=
get_output_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
assert
len
(
baseline_batch_token_ids
)
==
len
(
prompts
)
assert
len
(
spec_batch_token_ids
)
==
len
(
prompts
)
for
i
,
(
baseline_token_ids
,
baseline_tokens
,
spec_token_ids
,
spec_tokens
)
in
enumerate
(
zip
(
baseline_batch_token_ids
,
baseline_batch_tokens
,
spec_batch_token_ids
,
spec_batch_tokens
)):
if
print_tokens
:
print
(
f
'
{
i
=
}
{
baseline_tokens
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_tokens
=
}
'
)
print
(
f
'
{
i
=
}
{
baseline_token_ids
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_token_ids
=
}
'
)
assert
baseline_token_ids
==
spec_token_ids
tests/spec_decode/e2e/test_correctness.py
→
tests/spec_decode/e2e/test_
multistep_
correctness.py
View file @
b38e42fb
...
...
@@ -35,7 +35,8 @@ from transformers import AutoTokenizer
from
vllm
import
SamplingParams
from
.conftest
import
get_output_from_llm_generator
from
.conftest
import
(
get_output_from_llm_generator
,
run_greedy_equality_correctness_test
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -545,60 +546,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
def
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
,
force_output_len
:
bool
,
print_tokens
:
bool
=
False
):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
"San Francisco is know for its"
,
"Facebook was created in 2004 by"
,
"Curious George is a"
,
"Python 3.11 brings improvements to its"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos
=
force_output_len
sampling_params
=
SamplingParams
(
max_tokens
=
max_output_len
,
ignore_eos
=
ignore_eos
,
temperature
=
temperature
,
)
spec_batch_tokens
,
spec_batch_token_ids
=
get_output_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
(
baseline_batch_tokens
,
baseline_batch_token_ids
)
=
get_output_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
assert
len
(
baseline_batch_token_ids
)
==
len
(
prompts
)
assert
len
(
spec_batch_token_ids
)
==
len
(
prompts
)
for
i
,
(
baseline_token_ids
,
baseline_tokens
,
spec_token_ids
,
spec_tokens
)
in
enumerate
(
zip
(
baseline_batch_token_ids
,
baseline_batch_tokens
,
spec_batch_token_ids
,
spec_batch_tokens
)):
if
print_tokens
:
print
(
f
'
{
i
=
}
{
baseline_tokens
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_tokens
=
}
'
)
print
(
f
'
{
i
=
}
{
baseline_token_ids
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_token_ids
=
}
'
)
assert
baseline_token_ids
==
spec_token_ids
tests/spec_decode/e2e/test_ngram_correctness.py
0 → 100644
View file @
b38e42fb
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775.
Since there is no model is needed for generate the proposal, we could make
the testcase much simpler than drafter multi-step one.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various ngram sizes / speculative sizes
With those tests, we can say at least, ngram spec would not break the correctess
for the target model outputs.
"""
import
pytest
from
.conftest
import
run_greedy_equality_correctness_test
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model"
:
"JackFram/llama-68m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
64
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality on a tiny model with different batch size."""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"block_size"
:
8
,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override"
:
2
+
256
//
8
,
"max_model_len"
:
(
2
+
256
//
8
)
*
8
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"model"
:
"JackFram/llama-160m"
,
},
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use small output len for fast test.
256
,
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_e2e_greedy_correctness_with_preemption
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
k
,
"ngram_prompt_lookup_max"
:
3
,
}
# Try a range of common k, as well as large speculation.
for
k
in
[
1
,
3
,
5
]
]
+
[
{
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
k
,
"ngram_prompt_lookup_max"
:
1
,
}
# Try a range of common k, as well as large speculation.
for
k
in
[
1
,
3
,
5
]
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_different_k
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
force_output_len
=
True
)
tests/spec_decode/test_multi_step_worker.py
View file @
b38e42fb
...
...
@@ -6,8 +6,8 @@ import torch
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplerOutput
from
vllm.spec_decode.multi_step_worker
import
(
DraftModelTop1Propos
er
,
MultiStepWork
er
)
from
vllm.spec_decode.multi_step_worker
import
MultiStepWork
er
from
vllm.spec_decode.top1_proposer
import
Top1Propos
er
from
vllm.worker.worker
import
Worker
from
.utils
import
(
assert_logprobs_dict_allclose
,
create_batch
,
...
...
@@ -117,8 +117,8 @@ def test_same_output_for_single_step():
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
set_random_seed
(
seed
)
actual_output
=
multi_step_worker
.
execute_model_multi_step
(
**
multi_step_execute_model_data
.
to_dict
(),
num_steps
=
num_steps
)
actual_output
,
_
=
multi_step_worker
.
sampler_output
(
**
multi_step_execute_model_data
.
to_dict
(),
sample_len
=
num_steps
)
assert
len
(
actual_output
)
==
num_steps
actual_output
=
actual_output
[
0
]
...
...
@@ -200,8 +200,8 @@ def test_same_output_for_multi_step():
# Run multi-step.
zero_kv_cache
(
multi_step_worker
.
cache_engine
)
set_random_seed
(
seed
)
multi_step_output
=
multi_step_worker
.
execute_model_multi_step
(
**
execute_model_data
.
to_dict
(),
num_steps
=
num_steps
)
multi_step_output
,
_
=
multi_step_worker
.
sampler_output
(
**
execute_model_data
.
to_dict
(),
sample_len
=
num_steps
)
# Run single-step repeatedly.
zero_kv_cache
(
worker
.
cache_engine
)
...
...
@@ -266,7 +266,7 @@ def test_same_output_for_multi_step():
@
torch
.
inference_mode
()
def
test_draft_proposals_full_speculation_len
():
"""Verify
DraftModel
Top1Proposer correctly handles case where all sequences
"""Verify Top1Proposer correctly handles case where all sequences
can speculate.
"""
k
=
10
...
...
@@ -275,13 +275,13 @@ def test_draft_proposals_full_speculation_len():
device
=
'cuda:0'
draft_worker
=
MagicMock
()
proposer
=
DraftModel
Top1Proposer
(
draft_
worker
=
draft_worker
,
proposer
=
Top1Proposer
(
worker
=
draft_worker
,
device
=
device
,
max_model_len
=
2048
,
vocab_size
=
vocab_size
,
max_proposal_len
=
2048
,
)
draft_worker
.
execute_model_multi_step
.
return_value
=
[
draft_worker
.
sampler_output
.
return_value
=
[
SamplerOutput
(
outputs
=
[],
sampled_token_probs
=
torch
.
rand
(
batch_size
,
...
...
@@ -294,13 +294,13 @@ def test_draft_proposals_full_speculation_len():
device
=
device
,
dtype
=
torch
.
long
),
)
for
_
in
range
(
k
)
]
]
,
True
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
k
)
proposals
=
proposer
.
get_proposals
(
**
execute_model_data
.
to_dict
(),
max_
proposal_len
=
k
,
proposal_len
=
k
,
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
...
...
@@ -315,7 +315,7 @@ def test_draft_proposals_full_speculation_len():
@
torch
.
inference_mode
()
def
test_draft_proposals_no_speculations
():
"""Verify
DraftModel
Top1Proposer correctly handles case where no sequences
"""Verify Top1Proposer correctly handles case where no sequences
can speculate.
"""
k
=
10
...
...
@@ -325,11 +325,11 @@ def test_draft_proposals_no_speculations():
prompt_len
=
10
draft_worker
=
MagicMock
()
proposer
=
DraftModel
Top1Proposer
(
draft_
worker
=
draft_worker
,
proposer
=
Top1Proposer
(
worker
=
draft_worker
,
device
=
device
,
max_model_len
=
prompt_len
+
k
-
1
,
vocab_size
=
vocab_size
,
max_proposal_len
=
prompt_len
+
k
-
1
,
)
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
...
...
@@ -338,7 +338,7 @@ def test_draft_proposals_no_speculations():
proposals
=
proposer
.
get_proposals
(
**
execute_model_data
.
to_dict
(),
max_
proposal_len
=
k
,
proposal_len
=
k
,
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
...
...
@@ -353,7 +353,7 @@ def test_draft_proposals_no_speculations():
@
torch
.
inference_mode
()
def
test_draft_proposals_mixed_k
():
"""Verify
DraftModel
Top1Proposer correctly handles case some sequences can
"""Verify Top1Proposer correctly handles case some sequences can
speculate and some can't.
"""
k
=
10
...
...
@@ -374,14 +374,14 @@ def test_draft_proposals_mixed_k():
for
_
in
range
(
expected_num_no_proposal_seqs
)]
+
[
small_prompt_len
]
draft_worker
=
MagicMock
()
proposer
=
DraftModel
Top1Proposer
(
draft_
worker
=
draft_worker
,
proposer
=
Top1Proposer
(
worker
=
draft_worker
,
device
=
device
,
max_model_len
=
long_prompt_len
+
prev_output_token_len
+
k
-
1
,
vocab_size
=
vocab_size
,
max_proposal_len
=
long_prompt_len
+
prev_output_token_len
+
k
-
1
,
)
draft_worker
.
execute_model_multi_step
.
return_value
=
[
draft_worker
.
sampler_output
.
return_value
=
[
SamplerOutput
(
outputs
=
[],
sampled_token_probs
=
torch
.
rand
(
expected_num_proposal_seqs
,
...
...
@@ -395,7 +395,7 @@ def test_draft_proposals_mixed_k():
device
=
device
,
dtype
=
torch
.
long
),
)
for
_
in
range
(
k
)
]
]
,
True
execute_model_data
,
_
,
_
=
create_batch
(
batch_size
,
...
...
@@ -406,7 +406,7 @@ def test_draft_proposals_mixed_k():
proposals
=
proposer
.
get_proposals
(
**
execute_model_data
.
to_dict
(),
max_
proposal_len
=
k
,
proposal_len
=
k
,
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
...
...
tests/spec_decode/test_ngram_worker.py
0 → 100644
View file @
b38e42fb
import
torch
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
.utils
import
(
create_execute_model_data
,
create_seq_group_metadata_from_prompts
,
create_worker
)
def
test_ngram_algo_correctness_for_single_no_match
():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario cannot find any candidate in one single batch
"""
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
seed
=
100
model_name
=
'JackFram/llama-68m'
vocab_size
=
32_000
device
=
'cuda:0'
ngram_worker
=
create_worker
(
NGramWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
proposer
=
Top1Proposer
(
worker
=
ngram_worker
,
device
=
device
,
vocab_size
=
vocab_size
,
max_proposal_len
=
20
,
)
# set ngram window (0, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
0
,
3
)
prompts
=
[
# shall find no candidate
[
1
,
2
,
3
,
4
,
5
,
6
,
7
],
]
proposal_len
=
5
final_seq_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
ngram_sampler_output_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_seq_lens
=
final_seq_lens
))
proposals
=
proposer
.
get_proposals
(
**
ngram_sampler_output_data
.
to_dict
(),
proposal_len
=
proposal_len
,
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
1
,
proposal_len
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
1
,
proposal_len
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
1
])
assert
proposals
.
proposal_lens
.
tolist
()
==
[
0
]
def
test_ngram_algo_correctness_for_batches_not_match_all
():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find some candidate not full in batchs
"""
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
seed
=
100
model_name
=
'JackFram/llama-68m'
vocab_size
=
32_000
device
=
'cuda:0'
ngram_worker
=
create_worker
(
NGramWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
proposer
=
Top1Proposer
(
worker
=
ngram_worker
,
device
=
device
,
vocab_size
=
vocab_size
,
max_proposal_len
=
20
,
)
# set ngram window (0, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
0
,
3
)
prompts
=
[
# shall find no candidate
[
1
,
2
,
3
,
4
,
5
,
6
,
7
],
# shall find candidate 12,13,14,15,16
[
11
,
12
,
13
,
14
,
15
,
16
,
11
],
# shall find candidate 23,24,25,26,21
[
21
,
21
,
22
,
23
,
24
,
25
,
26
,
21
,
22
],
# shall find candidate 34,35,36,37,38
[
31
,
32
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
31
,
32
,
33
],
# shall find no candidate as exceed max_proposal_len
[
31
,
32
,
31
,
32
,
31
,
32
,
31
,
32
,
31
,
32
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
31
,
32
,
33
],
]
proposal_len
=
5
final_seq_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
ngram_sampler_output_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_seq_lens
=
final_seq_lens
))
proposals
=
proposer
.
get_proposals
(
**
ngram_sampler_output_data
.
to_dict
(),
proposal_len
=
proposal_len
,
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
5
,
proposal_len
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
5
,
proposal_len
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
5
])
assert
proposals
.
proposal_lens
.
tolist
(
)
==
[
proposal_len
for
_
in
range
(
4
)]
+
[
0
]
for
i
in
range
(
proposal_len
):
assert
proposals
.
proposal_token_ids
[
0
][
i
]
==
0
assert
proposals
.
proposal_token_ids
[
1
][
i
]
==
prompts
[
1
][
i
+
1
]
assert
proposals
.
proposal_token_ids
[
2
][
i
]
==
prompts
[
2
][
i
+
3
]
assert
proposals
.
proposal_token_ids
[
3
][
i
]
==
prompts
[
3
][
i
+
5
]
assert
proposals
.
proposal_token_ids
[
4
][
i
]
==
-
1
def
test_ngram_algo_correctness_for_batches_match_all
():
"""Verify our ngram algo find the right candidate in the prompt
For the scenario find candidate in all batchs
"""
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
seed
=
100
model_name
=
'JackFram/llama-68m'
vocab_size
=
32_000
device
=
'cuda:0'
ngram_worker
=
create_worker
(
NGramWorker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
,
)
proposer
=
Top1Proposer
(
worker
=
ngram_worker
,
device
=
device
,
vocab_size
=
vocab_size
,
max_proposal_len
=
20
,
)
# set ngram window (0, 3], which is window=1/2/3
ngram_worker
.
set_ngram_window_size
(
0
,
3
)
prompts
=
[
# shall find candidate 12,13,14,15,16
[
11
,
12
,
13
,
14
,
15
,
16
,
11
],
# shall find candidate 23,24,25,26,21
[
21
,
21
,
22
,
23
,
24
,
25
,
26
,
21
,
22
],
# shall find candidate 34,35,36,37,38
[
31
,
32
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
31
,
32
,
33
],
]
proposal_len
=
5
final_seq_lens
=
[
len
(
prompt
)
+
proposal_len
for
prompt
in
prompts
]
ngram_sampler_output_data
=
create_execute_model_data
(
seq_group_metadata_list
=
create_seq_group_metadata_from_prompts
(
prompts
,
num_gpu_blocks
,
block_size
,
final_seq_lens
=
final_seq_lens
))
proposals
=
proposer
.
get_proposals
(
**
ngram_sampler_output_data
.
to_dict
(),
proposal_len
=
proposal_len
,
)
assert
torch
.
is_tensor
(
proposals
.
proposal_token_ids
)
assert
torch
.
is_tensor
(
proposals
.
proposal_probs
)
assert
proposals
.
proposal_token_ids
.
shape
==
torch
.
Size
([
3
,
proposal_len
])
assert
proposals
.
proposal_probs
.
shape
[:
-
1
]
==
torch
.
Size
([
3
,
proposal_len
])
assert
proposals
.
proposal_lens
.
shape
==
torch
.
Size
([
3
])
assert
proposals
.
proposal_lens
.
tolist
()
==
[
proposal_len
for
_
in
range
(
3
)]
for
i
in
range
(
proposal_len
):
assert
proposals
.
proposal_token_ids
[
0
][
i
]
==
prompts
[
0
][
i
+
1
]
assert
proposals
.
proposal_token_ids
[
1
][
i
]
==
prompts
[
1
][
i
+
3
]
assert
proposals
.
proposal_token_ids
[
2
][
i
]
==
prompts
[
2
][
i
+
5
]
vllm/config.py
View file @
b38e42fb
...
...
@@ -682,6 +682,8 @@ class SpeculativeConfig:
speculative_max_model_len
:
Optional
[
int
],
enable_chunked_prefill
:
bool
,
use_v2_block_manager
:
bool
,
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
)
->
Optional
[
"SpeculativeConfig"
]:
"""Create a SpeculativeConfig if possible, else return None.
...
...
@@ -708,6 +710,10 @@ class SpeculativeConfig:
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
window, if provided.
Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
...
...
@@ -742,6 +748,22 @@ class SpeculativeConfig:
draft_code_revision
=
None
draft_quantization
=
None
if
speculative_model
==
"[ngram]"
:
assert
(
ngram_prompt_lookup_max
is
not
None
and
ngram_prompt_lookup_max
>
0
)
if
ngram_prompt_lookup_min
is
None
:
ngram_prompt_lookup_min
=
0
else
:
assert
ngram_prompt_lookup_max
>
ngram_prompt_lookup_min
# TODO: current we still need extract vocab_size from target model
# config, in future, we may try refactor it out, and set
# draft related config as None here.
draft_model_config
=
target_model_config
draft_parallel_config
=
target_parallel_config
else
:
ngram_prompt_lookup_max
=
0
ngram_prompt_lookup_min
=
0
draft_model_config
=
ModelConfig
(
model
=
speculative_model
,
tokenizer
=
target_model_config
.
tokenizer
,
...
...
@@ -775,6 +797,8 @@ class SpeculativeConfig:
draft_model_config
,
draft_parallel_config
,
num_speculative_tokens
,
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
,
)
@
staticmethod
...
...
@@ -842,6 +866,8 @@ class SpeculativeConfig:
draft_model_config
:
ModelConfig
,
draft_parallel_config
:
ParallelConfig
,
num_speculative_tokens
:
int
,
ngram_prompt_lookup_max
:
int
,
ngram_prompt_lookup_min
:
int
,
):
"""Create a SpeculativeConfig object.
...
...
@@ -854,6 +880,8 @@ class SpeculativeConfig:
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
num_speculative_tokens
=
num_speculative_tokens
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
self
.
_verify_args
()
...
...
@@ -877,6 +905,9 @@ class SpeculativeConfig:
return
self
.
num_speculative_tokens
def
__repr__
(
self
)
->
str
:
if
self
.
ngram_prompt_lookup_max
>
0
:
draft_model
=
"[ngram]"
else
:
draft_model
=
self
.
draft_model_config
.
model
num_spec_tokens
=
self
.
num_speculative_tokens
return
f
"SpeculativeConfig(
{
draft_model
=
}
,
{
num_spec_tokens
=
}
)"
...
...
vllm/engine/arg_utils.py
View file @
b38e42fb
...
...
@@ -75,6 +75,8 @@ class EngineArgs:
speculative_model
:
Optional
[
str
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
speculative_max_model_len
:
Optional
[
int
]
=
None
ngram_prompt_lookup_max
:
Optional
[
int
]
=
None
ngram_prompt_lookup_min
:
Optional
[
int
]
=
None
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
...
...
@@ -449,6 +451,20 @@ class EngineArgs:
'draft model. Sequences over this length will skip '
'speculation.'
)
parser
.
add_argument
(
'--ngram-prompt-lookup-max'
,
type
=
int
,
default
=
EngineArgs
.
ngram_prompt_lookup_max
,
help
=
'Max size of window for ngram prompt lookup in speculative '
'decoding.'
)
parser
.
add_argument
(
'--ngram-prompt-lookup-min'
,
type
=
int
,
default
=
EngineArgs
.
ngram_prompt_lookup_min
,
help
=
'Min size of window for ngram prompt lookup in speculative '
'decoding.'
)
parser
.
add_argument
(
'--model-loader-extra-config'
,
type
=
str
,
default
=
EngineArgs
.
model_loader_extra_config
,
...
...
@@ -502,6 +518,8 @@ class EngineArgs:
speculative_max_model_len
=
self
.
speculative_max_model_len
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
ngram_prompt_lookup_max
=
self
.
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
=
self
.
ngram_prompt_lookup_min
,
)
scheduler_config
=
SchedulerConfig
(
...
...
vllm/executor/gpu_executor.py
View file @
b38e42fb
...
...
@@ -73,7 +73,6 @@ class GPUExecutor(ExecutorBase):
"""
assert
self
.
speculative_config
is
not
None
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.spec_decode_worker
import
SpecDecodeWorker
target_worker
=
self
.
_create_worker
()
...
...
@@ -86,10 +85,11 @@ class GPUExecutor(ExecutorBase):
# TODO allow draft-model specific load config.
#load_config=self.load_config,
)
draft_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
spec_decode_worker
=
SpecDecodeWorker
.
from_workers
(
proposer_worker
=
draft_worker
,
scorer_worker
=
target_worker
)
spec_decode_worker
=
SpecDecodeWorker
.
create_worker
(
scorer_worker
=
target_worker
,
draft_worker_kwargs
=
draft_worker_kwargs
,
)
assert
self
.
parallel_config
.
world_size
==
1
,
(
"GPUExecutor only supports single GPU."
)
...
...
vllm/spec_decode/batch_expansion.py
View file @
b38e42fb
...
...
@@ -333,13 +333,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
sampler_output
.
sampled_token_probs
=
spec_probs
sampler_output
.
sampled_token_ids
=
spec_sampled_tokens
target_token_ids
,
target_probs
=
sampler_output_to_torch
(
[
sampler_output
])
[
sampler_output
]
,
True
)
# Convert non-speculative output tokens to tensors.
sampler_output
.
sampled_token_probs
=
non_spec_probs
sampler_output
.
sampled_token_ids
=
non_spec_sampled_tokens
non_spec_target_token_ids
,
non_spec_target_probs
=
(
sampler_output_to_torch
([
sampler_output
]))
sampler_output_to_torch
([
sampler_output
]
,
True
))
return
(
target_token_ids
,
target_probs
,
non_spec_target_token_ids
,
non_spec_target_probs
)
...
...
vllm/spec_decode/multi_step_worker.py
View file @
b38e42fb
import
copy
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Tuple
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
...
...
@@ -26,29 +25,37 @@ class MultiStepWorker(Worker):
super
().
__init__
(
*
args
,
**
kwargs
)
# Lazy initialization list.
self
.
_proposer
:
DraftModel
Top1Proposer
self
.
_proposer
:
Top1Proposer
def
init_device
(
self
):
super
().
init_device
()
self
.
_proposer
=
DraftModel
Top1Proposer
(
self
.
_proposer
=
Top1Proposer
(
self
,
self
.
device
,
self
.
max_model_len
,
self
.
vocab_size
,
max_proposal_len
=
self
.
max_model_len
,
)
def
set_include_gpu_probs_tensor
(
self
):
# Need include_gpu_probs_tensor for multi_step_worker
self
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
=
True
@
torch
.
inference_mode
()
def
execute_model_multi_step
(
def
sampler_output
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_steps
:
int
,
)
->
List
[
SamplerOutput
]:
"""Run the model forward pass num_steps times. Returns the list of
sampler output, one per model forward pass.
sample_len
:
int
,
)
->
Tuple
[
List
[
SamplerOutput
],
bool
]:
"""Run the model forward pass sample_len times. Returns the list of
sampler output, one per model forward pass, along with indicator of
whether torch tensor in sampler output need to be transposed in latter
sampler_output_to_torch logic.
For multi step worker, this indicator shall be True.
"""
self
.
_raise_if_unsupported
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
)
...
...
@@ -58,12 +65,12 @@ class MultiStepWorker(Worker):
copied_seq_group_metadata_list
=
self
.
_shallow_copy_inputs
(
seq_group_metadata_list
)
# Assert enough KV space for
num_steps
tokens per sequence.
self
.
_assert_enough_kv_space
(
seq_group_metadata_list
,
num_steps
)
# Assert enough KV space for
sample_len
tokens per sequence.
self
.
_assert_enough_kv_space
(
seq_group_metadata_list
,
sample_len
)
# Run model
num_steps
times.
# Run model
sample_len
times.
model_outputs
=
[]
for
_
in
range
(
num_steps
):
for
_
in
range
(
sample_len
):
model_output
=
super
().
execute_model
(
seq_group_metadata_list
=
copied_seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
...
...
@@ -78,7 +85,7 @@ class MultiStepWorker(Worker):
copied_seq_group_metadata_list
)
model_outputs
.
append
(
model_output
)
return
model_outputs
return
model_outputs
,
True
def
get_spec_proposals
(
self
,
...
...
@@ -206,171 +213,3 @@ class MultiStepWorker(Worker):
for
seq_group_metadata
in
seq_group_metadata_list
):
raise
NotImplementedError
(
"MultiStepWorker does not support beam search."
)
class
DraftModelTop1Proposer
(
SpeculativeProposer
):
"""Helper class which separates out sequences which would exceed the max
model length when speculated upon.
This allows combinations of models such as JackFram/llama-68m draft with
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
2048 while Llama2-13b has max_position_embeddings of 4096.
We treat the sequences which exceed the proposal draft model length as
"non-spec sequences". Essentially they skip the draft model and go through
normal decoding in the target model.
Currently, only proposal_lens of 0 and k are supported, where k is a global
batch proposal length. In the future vLLM should support per-sequence
proposal lengths.
"""
def
__init__
(
self
,
draft_worker
:
MultiStepWorker
,
device
:
str
,
max_model_len
:
int
,
vocab_size
:
int
,
):
self
.
_draft_worker
=
draft_worker
self
.
_device
=
device
self
.
_max_model_len
=
max_model_len
self
.
_vocab_size
=
vocab_size
def
get_proposals
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
max_proposal_len
:
int
,
)
->
SpeculativeProposals
:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
# Split speculative- and non-speculative- sequences.
(
proposal_lens
,
nonzero_proposal_len_seqs
,
nonzero_proposal_len_indices
)
=
self
.
_split_by_max_model_len
(
seq_group_metadata_list
,
max_proposal_len
)
if
nonzero_proposal_len_seqs
:
# Speculate tokens using the draft worker for the speculative
# sequences.
maybe_sampler_output
=
self
.
_draft_worker
.
execute_model_multi_step
(
seq_group_metadata_list
=
nonzero_proposal_len_seqs
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
num_steps
=
max_proposal_len
,
)
else
:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output
=
None
# Combine speculative- and non-speculative sequences into the same
# representation.
proposal_tokens
,
proposal_probs
,
proposal_lens
=
self
.
_merge_outputs
(
batch_size
=
len
(
seq_group_metadata_list
),
max_proposal_len
=
max_proposal_len
,
maybe_sampler_output
=
maybe_sampler_output
,
proposal_lens
=
proposal_lens
,
nonzero_proposal_len_indices
=
nonzero_proposal_len_indices
,
)
proposals
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_tokens
,
proposal_probs
=
proposal_probs
,
proposal_lens
=
proposal_lens
,
)
return
proposals
def
_split_by_max_model_len
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
max_proposal_len
:
int
,
)
->
Tuple
[
List
[
int
],
List
[
SequenceGroupMetadata
],
List
[
int
]]:
"""Determine which sequences would exceed the max model length.
"""
proposal_lens
:
List
[
int
]
=
[]
nonzero_proposal_len_seqs
:
List
[
SequenceGroupMetadata
]
=
[]
nonzero_proposal_len_indices
:
List
[
int
]
=
[]
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
seq_len
=
seq_data
.
get_len
()
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
if
seq_len
+
max_proposal_len
<
self
.
_max_model_len
:
proposal_lens
.
append
(
max_proposal_len
)
nonzero_proposal_len_seqs
.
append
(
seq_group_metadata
)
nonzero_proposal_len_indices
.
append
(
i
)
else
:
proposal_lens
.
append
(
0
)
return
(
proposal_lens
,
nonzero_proposal_len_seqs
,
nonzero_proposal_len_indices
)
def
_merge_outputs
(
self
,
batch_size
:
int
,
max_proposal_len
:
int
,
maybe_sampler_output
:
Optional
[
SamplerOutput
],
proposal_lens
:
List
[
int
],
nonzero_proposal_len_indices
:
List
[
int
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
tensor
,
torch
.
Tensor
]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if
maybe_sampler_output
is
None
:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens
=
torch
.
full
(
size
=
(
batch_size
,
max_proposal_len
,
),
fill_value
=-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
)
proposal_probs
=
torch
.
zeros
(
batch_size
,
max_proposal_len
,
self
.
_vocab_size
,
dtype
=
torch
.
float32
,
device
=
self
.
_device
)
proposal_lens_tensor
=
torch
.
zeros
(
len
(
proposal_lens
),
dtype
=
torch
.
long
,
device
=
self
.
_device
)
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
sampler_output
=
maybe_sampler_output
proposal_tokens
,
proposal_probs
=
sampler_output_to_torch
(
sampler_output
)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens
=
torch
.
full
(
size
=
(
batch_size
,
*
proposal_tokens
.
shape
[
1
:]),
fill_value
=-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
)
entire_proposal_tokens
[
nonzero_proposal_len_indices
]
=
proposal_tokens
entire_proposal_probs
=
torch
.
zeros
(
batch_size
,
*
proposal_probs
.
shape
[
1
:],
dtype
=
torch
.
float32
,
device
=
self
.
_device
)
entire_proposal_probs
[
nonzero_proposal_len_indices
]
=
proposal_probs
proposal_tokens
,
proposal_probs
=
(
entire_proposal_tokens
,
entire_proposal_probs
)
proposal_lens_tensor
=
torch
.
zeros
(
batch_size
,
dtype
=
torch
.
long
,
device
=
self
.
_device
)
proposal_lens_tensor
[
nonzero_proposal_len_indices
]
=
max_proposal_len
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
vllm/spec_decode/ngram_worker.py
0 → 100644
View file @
b38e42fb
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
class
NGramWorker
(
LoraNotSupportedWorkerBase
):
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implement prompt lookup decoding,
and in future we may also do RAG type drafter and other scenerios
which don't rely on LLM model to give proposals.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
# Get local_rank/vocab_size from kwargs attribute
self
.
local_rank
=
kwargs
[
"local_rank"
]
self
.
vocab_size
=
kwargs
[
"model_config"
].
get_vocab_size
()
# Lazy initialization list.
self
.
_proposer
:
Top1Proposer
def
set_ngram_window_size
(
self
,
ngram_prompt_lookup_min
:
int
,
ngram_prompt_lookup_max
:
int
):
# Search valid candidate window between
# ngram_prompt_lookup_min/ngram_prompt_lookup_max
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
def
init_device
(
self
):
self
.
device
=
torch
.
device
(
f
"cuda:
{
self
.
local_rank
}
"
)
self
.
load_model
=
lambda
*
args
,
**
kwargs
:
None
# Current only support Top1Proposer
self
.
_proposer
=
Top1Proposer
(
self
,
device
=
self
.
device
,
vocab_size
=
self
.
vocab_size
,
)
def
set_include_gpu_probs_tensor
(
self
):
# NGram don't need gpu sampler
pass
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
)
->
None
:
"""NGram doesn't depend on model execution, just pass this function"""
pass
def
determine_num_available_blocks
(
self
)
->
None
:
"""NGram doesn't depend on model execution, no need to check blocks"""
pass
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
"""As there is no cache need to handle, just pass this function"""
pass
def
get_cache_block_size_bytes
(
self
):
"""Return the size of a cache block in bytes."""
return
0
def
sampler_output
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
sample_len
:
int
,
)
->
Tuple
[
Optional
[
List
[
SamplerOutput
]],
bool
]:
"""NGram match algo to pick proposal candidate. Returns the list of
sampler output, one per SequenceGroupMetadata.
For ngram worker, we already done needed transposed internal, so the
indicator pass to sampler_output_to_torch shall be False.
"""
self
.
_raise_if_unsupported
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
)
arr
=
[]
has_spec_out
=
False
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
input_ids
=
torch
.
as_tensor
(
seq_data
.
get_token_ids
(),
dtype
=
torch
.
long
,
device
=
self
.
device
)
input_length
=
seq_data
.
get_len
()
for
ngram_size
in
range
(
min
(
self
.
ngram_prompt_lookup_max
,
input_length
-
1
),
self
.
ngram_prompt_lookup_min
,
-
1
,
):
ngram_tensor
=
input_ids
[
-
1
*
ngram_size
:]
windows
=
input_ids
.
unfold
(
dimension
=
0
,
size
=
ngram_size
,
step
=
1
)
matches
=
(
windows
==
ngram_tensor
).
all
(
dim
=
1
)
match_indices
=
matches
.
nonzero
(
as_tuple
=
True
)[
0
]
if
match_indices
.
size
()[
0
]
>
1
:
has_spec_out
=
True
res
=
seq_data
.
get_token_ids
()
res
=
res
[
match_indices
[
0
]
+
ngram_size
:
match_indices
[
0
]
+
ngram_size
+
sample_len
]
res_len
=
len
(
res
)
# pad 0 towards output as sample_len tokens required
res
+=
[
0
]
*
(
sample_len
-
res_len
)
break
else
:
# if no candidate found, fill with 0
res
=
[
0
]
*
sample_len
arr
.
append
(
res
)
if
not
has_spec_out
:
return
None
,
False
outputs
=
[]
token_ids
=
torch
.
as_tensor
(
arr
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
indices
=
token_ids
.
unsqueeze
(
2
)
token_probs
=
torch
.
zeros
(
(
len
(
seq_group_metadata_list
),
sample_len
,
self
.
vocab_size
),
dtype
=
torch
.
float32
,
device
=
self
.
device
,
)
token_probs
.
scatter_
(
2
,
indices
,
1
)
for
i
in
range
(
len
(
seq_group_metadata_list
)):
outputs
.
append
(
SamplerOutput
(
outputs
=
None
,
sampled_token_probs
=
token_probs
[
i
],
sampled_token_ids
=
token_ids
[
i
],
))
return
outputs
,
False
def
get_spec_proposals
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
max_proposal_len
:
int
,
)
->
SpeculativeProposals
:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return
self
.
_proposer
.
get_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
max_proposal_len
,
)
def
_raise_if_unsupported
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
)
->
None
:
"""NGramWorker does not yet implement support for cache swap
operations or beam search.
"""
if
any
([
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
]):
raise
NotImplementedError
(
"NGramWorker does not support cache operations"
)
if
any
(
len
(
seq_group_metadata
.
seq_data
.
keys
())
!=
1
for
seq_group_metadata
in
seq_group_metadata_list
):
raise
NotImplementedError
(
"NGramWorker does not support beam search."
)
vllm/spec_decode/spec_decode_worker.py
View file @
b38e42fb
...
...
@@ -12,6 +12,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.ngram_worker
import
NGramWorker
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
split_batch_by_proposal_len
)
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
,
WorkerBase
...
...
@@ -48,8 +49,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
@
classmethod
def
from_workers
(
cls
,
proposer_worker
:
MultiStepWorker
,
scorer_worker
:
WorkerBase
)
->
"SpecDecodeWorker"
:
def
create_worker
(
cls
,
scorer_worker
:
WorkerBase
,
draft_worker_kwargs
,
)
->
"SpecDecodeWorker"
:
if
"ngram_prompt_lookup_max"
in
draft_worker_kwargs
:
ngram_prompt_lookup_max
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_max"
))
ngram_prompt_lookup_min
=
(
draft_worker_kwargs
.
pop
(
"ngram_prompt_lookup_min"
))
else
:
ngram_prompt_lookup_max
=
0
if
ngram_prompt_lookup_max
>
0
:
proposer_worker
=
NGramWorker
(
**
draft_worker_kwargs
)
proposer_worker
.
set_ngram_window_size
(
ngram_prompt_lookup_min
,
ngram_prompt_lookup_max
)
else
:
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
...
...
@@ -59,7 +79,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def
__init__
(
self
,
proposer_worker
:
MultiStep
Worker
,
proposer_worker
:
Worker
Base
,
scorer_worker
:
WorkerBase
,
rejection_sampler
:
RejectionSampler
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
...
...
@@ -134,8 +154,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
(
self
.
scorer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
)
=
True
(
self
.
proposer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
)
=
True
self
.
proposer_worker
.
set_include_gpu_probs_tensor
()
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""Determine the number of cache blocks to use.
...
...
@@ -183,8 +202,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"speculative decoding "
"requires non-None seq_group_metadata_list"
)
logger
.
info
(
"spec_decode_worker.execute_model num_lookahead_slots=%d"
,
num_lookahead_slots
)
#
logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
#
num_lookahead_slots)
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
...
...
@@ -216,7 +235,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer and scorer model so that the KV cache is consistent between the
two.
"""
logger
.
info
(
"run proposer worker no spec"
)
#
logger.info("run proposer worker no spec")
self
.
proposer_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
...
...
@@ -225,7 +244,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
blocks_to_copy
=
blocks_to_copy
,
)
logger
.
info
(
"run target worker no spec"
)
#
logger.info("run target worker no spec")
sampler_output
=
self
.
scorer_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
...
...
@@ -259,7 +278,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sequence.
"""
logger
.
info
(
"get spec proposals"
)
#
logger.info("get spec proposals")
# Generate proposals using draft worker.
assert
blocks_to_swap_in
is
not
None
assert
blocks_to_swap_out
is
not
None
...
...
@@ -268,7 +287,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
k
)
logger
.
info
(
"score proposals"
)
#
logger.info("score proposals")
proposal_scores
=
self
.
scorer
.
score_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
...
...
@@ -278,11 +297,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals
,
)
logger
.
info
(
"verify proposals"
)
#
logger.info("verify proposals")
accepted_token_ids
=
self
.
_verify_tokens
(
seq_group_metadata_list
,
proposal_scores
,
proposals
,
k
)
logger
.
info
(
"create output list"
)
#
logger.info("create output list")
return
self
.
_create_output_sampler_list
(
seq_group_metadata_list
,
accepted_token_ids
,
k
)
...
...
vllm/spec_decode/top1_proposer.py
0 → 100644
View file @
b38e42fb
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.worker.worker_base
import
WorkerBase
class
Top1Proposer
(
SpeculativeProposer
):
"""Helper class which separates out sequences which would exceed the max
model length when speculated upon.
This allows combinations of models such as JackFram/llama-68m draft with
meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
2048 while Llama2-13b has max_position_embeddings of 4096.
We treat the sequences which exceed the proposal draft model length as
"non-spec sequences". Essentially they skip the draft model and go through
normal decoding in the target model.
Currently, only proposal_lens of 0 and k are supported, where k is a global
batch proposal length. In the future vLLM should support per-sequence
proposal lengths.
"""
def
__init__
(
self
,
worker
:
WorkerBase
,
device
:
str
,
vocab_size
:
int
,
max_proposal_len
:
Optional
[
int
]
=
None
,
):
self
.
_worker
=
worker
self
.
_device
=
device
self
.
max_proposal_len
=
max_proposal_len
self
.
_vocab_size
=
vocab_size
def
get_proposals
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
proposal_len
:
int
,
)
->
SpeculativeProposals
:
"""Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during
speculation.
"""
# Split speculative- and non-speculative- sequences.
(
proposal_lens
,
nonzero_proposal_len_seqs
,
nonzero_proposal_len_indices
,
)
=
self
.
_split_by_max_model_len
(
seq_group_metadata_list
,
proposal_len
)
if
nonzero_proposal_len_seqs
:
# Speculate tokens using the draft worker for the speculative
# sequences.
# If sampler_transposed is true, then maybe_sampler_output's
# token_ids is like [batch] format in proposal_len size list,
# while if it is false, the format would be [proposal_len]
# in batch size list
maybe_sampler_output
,
transposed
=
self
.
_worker
.
sampler_output
(
seq_group_metadata_list
=
nonzero_proposal_len_seqs
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
sample_len
=
proposal_len
,
)
else
:
# If no sequences can be speculated, set sampler output to None.
maybe_sampler_output
=
None
transposed
=
False
# Combine speculative- and non-speculative sequences into the same
# representation.
proposal_tokens
,
proposal_probs
,
proposal_lens
=
self
.
_merge_outputs
(
batch_size
=
len
(
seq_group_metadata_list
),
proposal_len
=
proposal_len
,
maybe_sampler_output
=
maybe_sampler_output
,
proposal_lens
=
proposal_lens
,
nonzero_proposal_len_indices
=
nonzero_proposal_len_indices
,
sampler_transposed
=
transposed
,
)
proposals
=
SpeculativeProposals
(
proposal_token_ids
=
proposal_tokens
,
proposal_probs
=
proposal_probs
,
proposal_lens
=
proposal_lens
,
)
return
proposals
def
_split_by_max_model_len
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_len
:
int
,
)
->
Tuple
[
List
[
int
],
List
[
SequenceGroupMetadata
],
List
[
int
]]:
"""Determine which sequences would exceed the max model length."""
proposal_lens
:
List
[
int
]
=
[]
nonzero_proposal_len_seqs
:
List
[
SequenceGroupMetadata
]
=
[]
nonzero_proposal_len_indices
:
List
[
int
]
=
[]
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_data
=
next
(
iter
(
seq_group_metadata
.
seq_data
.
values
()))
seq_len
=
seq_data
.
get_len
()
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
# If max_proposal_len is defined, then we shall no exccess this
# quota for nonzero_proposal
if
(
self
.
max_proposal_len
is
None
or
seq_len
+
proposal_len
<
self
.
max_proposal_len
):
proposal_lens
.
append
(
proposal_len
)
nonzero_proposal_len_seqs
.
append
(
seq_group_metadata
)
nonzero_proposal_len_indices
.
append
(
i
)
else
:
proposal_lens
.
append
(
0
)
return
(
proposal_lens
,
nonzero_proposal_len_seqs
,
nonzero_proposal_len_indices
,
)
def
_merge_outputs
(
self
,
batch_size
:
int
,
proposal_len
:
int
,
maybe_sampler_output
:
Optional
[
SamplerOutput
],
proposal_lens
:
List
[
int
],
nonzero_proposal_len_indices
:
List
[
int
],
sampler_transposed
:
bool
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
tensor
,
torch
.
Tensor
]:
"""After speculations are produced, merge the speculation results with
the skipped sequences.
"""
if
maybe_sampler_output
is
None
:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens
=
torch
.
full
(
size
=
(
batch_size
,
proposal_len
,
),
fill_value
=-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
,
)
proposal_probs
=
torch
.
zeros
(
batch_size
,
proposal_len
,
self
.
_vocab_size
,
dtype
=
torch
.
float32
,
device
=
self
.
_device
,
)
proposal_lens_tensor
=
torch
.
zeros
(
len
(
proposal_lens
),
dtype
=
torch
.
long
,
device
=
self
.
_device
)
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
sampler_output
=
maybe_sampler_output
proposal_tokens
,
proposal_probs
=
sampler_output_to_torch
(
sampler_output
,
sampler_transposed
)
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens
=
torch
.
full
(
size
=
(
batch_size
,
*
proposal_tokens
.
shape
[
1
:]),
fill_value
=-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
,
)
entire_proposal_tokens
[
nonzero_proposal_len_indices
]
=
proposal_tokens
entire_proposal_probs
=
torch
.
zeros
(
batch_size
,
*
proposal_probs
.
shape
[
1
:],
dtype
=
torch
.
float32
,
device
=
self
.
_device
,
)
entire_proposal_probs
[
nonzero_proposal_len_indices
]
=
proposal_probs
proposal_tokens
,
proposal_probs
=
(
entire_proposal_tokens
,
entire_proposal_probs
,
)
proposal_lens_tensor
=
torch
.
zeros
(
batch_size
,
dtype
=
torch
.
long
,
device
=
self
.
_device
)
proposal_lens_tensor
[
nonzero_proposal_len_indices
]
=
proposal_len
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
vllm/spec_decode/util.py
View file @
b38e42fb
...
...
@@ -50,9 +50,12 @@ def split_batch_by_proposal_len(
def
sampler_output_to_torch
(
sampler_output_list
:
List
[
SamplerOutput
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
sampler_transposed
:
bool
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
we need do additional tensor transpose logic here.
Returns:
sampled_token_ids: torch.Tensor
shape: [batch_size, len(sampler_output_list)]
...
...
@@ -68,7 +71,10 @@ def sampler_output_to_torch(
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
).
transpose
(
0
,
1
)
)
if
sampler_transposed
:
sampled_token_probs
=
sampled_token_probs
.
transpose
(
0
,
1
)
# shape: [batch_size, num_sampler_output]
sampled_token_ids
=
torch
.
stack
(
...
...
@@ -77,7 +83,9 @@ def sampler_output_to_torch(
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
).
transpose
(
0
,
1
)
)
if
sampler_transposed
:
sampled_token_ids
=
sampled_token_ids
.
transpose
(
0
,
1
)
return
sampled_token_ids
,
sampled_token_probs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment