Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
15702038
"vscode:/vscode.git/clone" did not exist on "9103ed16967390f3bbd6df104dcd162db43d3148"
Unverified
Commit
15702038
authored
Oct 01, 2024
by
Lily Liu
Committed by
GitHub
Oct 01, 2024
Browse files
[Spec Decode] (1/2) Remove batch expansion (#8839)
parent
22f5851b
Changes
29
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
356 additions
and
36 deletions
+356
-36
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-1
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+1
-1
tests/spec_decode/e2e/test_integration.py
tests/spec_decode/e2e/test_integration.py
+44
-0
tests/spec_decode/e2e/test_medusa_correctness.py
tests/spec_decode/e2e/test_medusa_correctness.py
+49
-0
tests/spec_decode/e2e/test_mlp_correctness.py
tests/spec_decode/e2e/test_mlp_correctness.py
+43
-0
tests/spec_decode/e2e/test_ngram_correctness.py
tests/spec_decode/e2e/test_ngram_correctness.py
+46
-0
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+0
-1
tests/spec_decode/test_scorer.py
tests/spec_decode/test_scorer.py
+65
-0
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+5
-4
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+16
-13
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+6
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+29
-7
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+0
-2
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+8
-0
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+2
-1
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+6
-0
vllm/config.py
vllm/config.py
+7
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+8
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+14
-4
vllm/engine/output_processor/interfaces.py
vllm/engine/output_processor/interfaces.py
+6
-2
No files found.
.buildkite/test-pipeline.yaml
View file @
15702038
...
@@ -208,7 +208,7 @@ steps:
...
@@ -208,7 +208,7 @@ steps:
-
tests/spec_decode
-
tests/spec_decode
commands
:
commands
:
-
pytest -v -s spec_decode/e2e/test_multistep_correctness.py
-
pytest -v -s spec_decode/e2e/test_multistep_correctness.py
-
pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
-
VLLM_ATTENTION_BACKEND=FLASH_ATTN
pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
-
label
:
LoRA Test %N
# 15min each
-
label
:
LoRA Test %N
# 15min each
mirror_hardwares
:
[
amd
]
mirror_hardwares
:
[
amd
]
...
...
tests/samplers/test_sampler.py
View file @
15702038
...
@@ -434,7 +434,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -434,7 +434,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
sampling_metadata
=
SamplingMetadata
.
prepare
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_group_metadata_list
,
seq_lens
=
seq_lens
if
seq_lens
else
None
,
seq_lens
=
seq_lens
if
seq_lens
else
None
,
query_lens
=
seq_lens
if
seq_lens
else
Non
e
,
query_lens
=
seq_lens
if
seq_lens
else
[
1
]
*
batch_siz
e
,
device
=
device
,
device
=
device
,
pin_memory
=
is_pin_memory_available
())
pin_memory
=
is_pin_memory_available
())
# the logits tensor is modified in-place by the sampler
# the logits tensor is modified in-place by the sampler
...
...
tests/spec_decode/e2e/test_integration.py
View file @
15702038
...
@@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
...
@@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
max_output_len
=
32
,
max_output_len
=
32
,
seed
=
seed
,
seed
=
seed
,
temperature
=
0.0
)
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
MAIN_MODEL
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"speculative_model"
:
"JackFram/llama-68m"
,
"num_speculative_tokens"
:
3
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_disable_mqa_scorer"
:
True
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mqa_scorer
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
tests/spec_decode/e2e/test_medusa_correctness.py
View file @
15702038
...
@@ -350,6 +350,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
...
@@ -350,6 +350,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
temperature
=
0.0
)
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model_name"
:
MAIN_MODEL
,
"speculative_model"
:
SPEC_MODEL
,
"num_speculative_tokens"
:
MAX_SPEC_TOKENS
,
"speculative_disable_by_batch_size"
:
4
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_disable_mqa_scorer"
:
True
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mqa_scorer
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
import
pytest
import
pytest
pytest
.
main
([
__file__
])
pytest
.
main
([
__file__
])
tests/spec_decode/e2e/test_mlp_correctness.py
View file @
15702038
...
@@ -460,3 +460,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
...
@@ -460,3 +460,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
max_output_len
=
output_len
,
max_output_len
=
output_len
,
seed
=
seed
,
seed
=
seed
,
temperature
=
0.0
)
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
MAIN_MODEL
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"speculative_model"
:
SPEC_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_disable_mqa_scorer"
:
True
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mqa_scorer
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
tests/spec_decode/e2e/test_ngram_correctness.py
View file @
15702038
...
@@ -292,3 +292,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
...
@@ -292,3 +292,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
max_output_len
=
output_len
,
max_output_len
=
output_len
,
seed
=
seed
,
seed
=
seed
,
temperature
=
0.0
)
temperature
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model_name"
:
"JackFram/llama-68m"
,
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
"speculative_model"
:
"[ngram]"
,
"num_speculative_tokens"
:
5
,
"ngram_prompt_lookup_max"
:
3
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"speculative_disable_mqa_scorer"
:
True
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
5
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
# Use smaller output len for fast test.
32
,
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_ngram_scorer
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
:
int
,
output_len
:
int
,
seed
:
int
):
"""Verify that ngram speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
run_equality_correctness_test
(
vllm_runner
,
common_llm_kwargs
,
per_test_common_llm_kwargs
,
baseline_llm_kwargs
,
test_llm_kwargs
,
batch_size
,
max_output_len
=
output_len
,
seed
=
seed
,
temperature
=
0.0
)
tests/spec_decode/test_multi_step_worker.py
View file @
15702038
...
@@ -173,7 +173,6 @@ def test_same_output_for_multi_step():
...
@@ -173,7 +173,6 @@ def test_same_output_for_multi_step():
block_size
,
block_size
,
num_gpu_blocks
,
num_gpu_blocks
,
seed
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
)
)
worker
=
create_worker
(
worker
=
create_worker
(
...
...
tests/spec_decode/test_scorer.py
0 → 100644
View file @
15702038
import
pytest
import
torch
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
,
SpeculativeScores
from
vllm.spec_decode.mqa_scorer
import
MQAScorer
from
vllm.worker.worker
import
Worker
from
.utils
import
create_batch
,
create_worker
def
create_proposal
(
batch_size
:
int
,
propose_len
:
int
,
vocab_size
:
int
,
device
:
str
)
->
SpeculativeProposals
:
proposal_probs
=
torch
.
rand
((
batch_size
,
propose_len
,
vocab_size
),
device
=
device
)
proposal_token_ids
=
torch
.
argmax
(
proposal_probs
,
dim
=-
1
)
proposal_lens
=
torch
.
tensor
([
propose_len
]
*
batch_size
,
device
=
device
)
return
SpeculativeProposals
(
proposal_token_ids
,
proposal_probs
,
proposal_lens
)
def
assert_score_equal
(
score1
:
SpeculativeScores
,
score2
:
SpeculativeScores
)
->
None
:
assert
torch
.
allclose
(
score1
.
probs
,
score2
.
probs
)
assert
torch
.
allclose
(
score1
.
logprobs
,
score2
.
logprobs
)
assert
torch
.
equal
(
score1
.
token_ids
,
score2
.
token_ids
)
@
pytest
.
mark
.
parametrize
(
'model_name'
,
[
'facebook/opt-125m'
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
4
,
8
,
16
])
@
pytest
.
mark
.
parametrize
(
'propose_len'
,
[
1
,
3
,
5
])
@
pytest
.
mark
.
parametrize
(
'device'
,
[
'cuda'
])
def
test_scoroer
(
model_name
:
str
,
batch_size
:
int
,
propose_len
:
int
,
device
:
str
)
->
None
:
"""
Check that the batch expansion scorer and the MQA scorer return the same scores.
"""
seed
=
0
block_size
=
32
num_gpu_blocks
=
2048
//
block_size
scorer_worker
=
create_worker
(
Worker
,
model_name
,
block_size
,
num_gpu_blocks
,
seed
)
scorer_worker
.
model_runner
.
model
.
sampler
.
include_gpu_probs_tensor
=
True
scorer_worker
.
model_runner
.
model
.
sampler
.
\
should_modify_greedy_probs_inplace
=
True
vocab_size
=
scorer_worker
.
vocab_size
proposals
=
create_proposal
(
batch_size
,
propose_len
,
vocab_size
,
device
)
seq_group_metadatalist
,
_
,
_
=
create_batch
(
batch_size
,
propose_len
,
block_size
=
block_size
,
num_gpu_blocks
=
num_gpu_blocks
)
requests
=
ExecuteModelRequest
(
seq_group_metadatalist
,
num_lookahead_slots
=
propose_len
)
batch_expansion_scorer
=
BatchExpansionTop1Scorer
(
scorer_worker
,
device
,
vocab_size
)
batch_expansion_score
=
batch_expansion_scorer
.
score_proposals
(
requests
,
proposals
)
mqa_scorer
=
MQAScorer
(
scorer_worker
,
device
,
vocab_size
)
mqa_score
=
mqa_scorer
.
score_proposals
(
requests
,
proposals
)
assert_score_equal
(
batch_expansion_score
,
mqa_score
)
tests/spec_decode/test_spec_decode_worker.py
View file @
15702038
...
@@ -63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
...
@@ -63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
[
"rejection_sampler"
,
"typical_acceptance_sampler"
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_correctly_calls_target_model
(
k
:
int
,
batch_size
:
int
,
def
test_
batch_expansion_
correctly_calls_target_model
(
acceptance_sampler_method
:
str
):
k
:
int
,
batch_size
:
int
,
acceptance_sampler_method
:
str
):
"""Verify SpecDecodeWorker calls the target model with correct
"""Verify SpecDecodeWorker calls the target model with correct
inputs. Everything else is mocked out.
inputs
with batch expansion
. Everything else is mocked out.
"""
"""
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
use_spec
=
False
)
draft_worker
=
mock_worker
(
cls
=
MultiStepWorker
,
use_spec
=
False
)
target_worker
=
mock_worker
(
use_spec
=
False
)
target_worker
=
mock_worker
(
use_spec
=
False
)
...
@@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
...
@@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
target_worker
,
target_worker
,
mock_spec_decode_sampler
(
acceptance_sampler_method
),
mock_spec_decode_sampler
(
acceptance_sampler_method
),
disable_logprobs
=
False
,
disable_logprobs
=
False
,
metrics_collector
=
metrics_collector
)
metrics_collector
=
metrics_collector
,
disable_mqa_scorer
=
True
)
worker
.
init_device
()
worker
.
init_device
()
vocab_size
=
32_000
vocab_size
=
32_000
...
...
tests/spec_decode/utils.py
View file @
15702038
...
@@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts(
...
@@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts(
for
i
,
final_len
in
enumerate
(
final_prompt_lens
)
for
i
,
final_len
in
enumerate
(
final_prompt_lens
)
}
}
return
[
seq_grou_metadata_list
=
[]
for
i
,
(
prompt_token_ids
,
cont_token_ids
)
in
enumerate
(
zip
(
prompts
,
continuations
)):
data
=
SequenceData
.
from_seqs
(
prompt_token_ids
,
cont_token_ids
)
data
.
update_num_computed_tokens
(
len
(
prompt_token_ids
)
+
len
(
cont_token_ids
)
-
1
)
seq_data
=
{
i
:
data
}
seq_grou_metadata_list
.
append
(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
str
(
i
),
request_id
=
str
(
i
),
is_prompt
=
len
(
cont_token_ids
)
==
0
,
is_prompt
=
len
(
cont_token_ids
)
==
0
,
seq_data
=
{
seq_data
=
seq_data
,
i
:
SequenceData
.
from_seqs
(
prompt_token_ids
[:],
sampling_params
=
SamplingParams
(
temperature
=
0.0
),
cont_token_ids
[:]),
},
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
),
block_tables
=
{
i
:
block_allocations
[
i
][:]},
block_tables
=
{
i
:
block_allocations
[
i
][:]},
)
for
i
,
(
prompt_token_ids
,
))
cont_token_ids
)
in
enumerate
(
zip
(
prompts
,
continuations
))
return
seq_grou_metadata_list
]
def
assert_logprobs_dict_allclose
(
def
assert_logprobs_dict_allclose
(
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
15702038
...
@@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
...
@@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
use_cuda_graph
:
bool
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len
:
Optional
[
int
]
=
None
_cached_prefill_metadata
:
Optional
[
_cached_prefill_metadata
:
Optional
[
"BlocksparseFlashAttentionMetadata"
]
=
None
"BlocksparseFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
_cached_decode_metadata
:
Optional
[
...
...
vllm/attention/backends/flash_attn.py
View file @
15702038
...
@@ -245,8 +245,15 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -245,8 +245,15 @@ class FlashAttentionMetadata(AttentionMetadata):
# |-------------------- seq_len ---------------------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# |-- query_len ---|
# Maximum query length in the batch.
None for decoding.
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
max_query_len
:
Optional
[
int
]
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len
:
Optional
[
int
]
# Maximum sequence length among prefill batch. 0 if there are decoding
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
# requests only.
max_prefill_seq_len
:
int
max_prefill_seq_len
:
int
...
@@ -303,6 +310,7 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -303,6 +310,7 @@ class FlashAttentionMetadata(AttentionMetadata):
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
decode_query_len
=
0
,
max_query_len
=
self
.
max_query_len
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_seq_len
=
0
,
max_decode_seq_len
=
0
,
...
@@ -331,7 +339,8 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -331,7 +339,8 @@ class FlashAttentionMetadata(AttentionMetadata):
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
seq_lens
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_query_len
=
None
,
decode_query_len
=
self
.
decode_query_len
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
0
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
query_start_loc
=
None
,
query_start_loc
=
None
,
...
@@ -461,9 +470,6 @@ class FlashAttentionMetadataBuilder(
...
@@ -461,9 +470,6 @@ class FlashAttentionMetadataBuilder(
self
.
num_prefill_tokens
+=
token_len
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
self
.
prefill_seq_lens
.
append
(
seq_len
)
else
:
else
:
assert
query_len
==
1
,
(
"seq_len: {}, context_len: {}, query_len: {}"
.
format
(
seq_len
,
context_len
,
query_len
))
self
.
num_decode_tokens
+=
query_len
self
.
num_decode_tokens
+=
query_len
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
...
@@ -518,6 +524,11 @@ class FlashAttentionMetadataBuilder(
...
@@ -518,6 +524,11 @@ class FlashAttentionMetadataBuilder(
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
max_query_len
=
max
(
query_lens
)
max_query_len
=
max
(
query_lens
)
decode_query_lens
=
query_lens
[
self
.
num_prefills
:]
if
len
(
decode_query_lens
)
>
0
:
decode_query_len
=
max
(
decode_query_lens
)
else
:
decode_query_len
=
1
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
num_decode_tokens
=
self
.
num_decode_tokens
...
@@ -586,6 +597,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -586,6 +597,7 @@ class FlashAttentionMetadataBuilder(
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
decode_query_len
=
decode_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
...
@@ -786,8 +798,12 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -786,8 +798,12 @@ class FlashAttentionImpl(AttentionImpl):
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Decoding run.
_
,
num_head
,
head_dim
=
decode_query
.
shape
decode_query
=
decode_query
.
reshape
(
-
1
,
decode_meta
.
decode_query_len
,
num_head
,
head_dim
)
decode_output
=
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
(
decode_output
=
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
(
decode_query
.
unsqueeze
(
1
)
,
decode_query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
block_table
=
decode_meta
.
block_tables
,
block_table
=
decode_meta
.
block_tables
,
...
@@ -796,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -796,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
softcap
=
self
.
logits_soft_cap
,
softcap
=
self
.
logits_soft_cap
,
)
.
squeeze
(
1
)
)
if
prefill_output
is
None
:
if
prefill_output
is
None
:
assert
decode_output
is
not
None
assert
decode_output
is
not
None
...
@@ -804,5 +820,11 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -804,5 +820,11 @@ class FlashAttentionImpl(AttentionImpl):
if
decode_output
is
None
:
if
decode_output
is
None
:
assert
prefill_output
is
not
None
assert
prefill_output
is
not
None
return
prefill_output
.
view
(
num_prefill_tokens
,
hidden_size
)
return
prefill_output
.
view
(
num_prefill_tokens
,
hidden_size
)
# Chunked prefill does not work with speculative decoding.
# Therefore, the query length for decode should be 1 in chunked prefill.
assert
decode_meta
is
not
None
assert
decode_meta
.
decode_query_len
==
1
decode_output
=
decode_output
.
squeeze
(
1
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/backends/flashinfer.py
View file @
15702038
...
@@ -595,7 +595,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -595,7 +595,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
device
=
self
.
runner
.
device
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
max_query_len
=
max
(
query_lens
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
num_decode_tokens
=
self
.
num_decode_tokens
...
@@ -634,7 +633,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -634,7 +633,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
device
=
device
,
device
=
device
,
)
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
assert
device
is
not
None
assert
device
is
not
None
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
15702038
...
@@ -116,9 +116,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -116,9 +116,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
# Cuda-graph is currently enabled for decoding only.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
use_cuda_graph
:
bool
# (batch_size,) A tensor of context lengths (tokens that are computed
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len
:
Optional
[
int
]
=
None
_cached_prefill_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
_cached_prefill_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
...
...
vllm/attention/backends/utils.py
View file @
15702038
...
@@ -312,7 +312,8 @@ class CommonAttentionState(AttentionState):
...
@@ -312,7 +312,8 @@ class CommonAttentionState(AttentionState):
slot_mapping
=
self
.
_graph_slot_mapping
[:
batch_size
],
slot_mapping
=
self
.
_graph_slot_mapping
[:
batch_size
],
seq_lens
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
_graph_seq_lens
[:
batch_size
],
seq_lens_tensor
=
self
.
_graph_seq_lens
[:
batch_size
],
max_query_len
=
None
,
max_query_len
=
1
,
decode_query_len
=
1
,
max_prefill_seq_len
=
0
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
runner
.
max_seq_len_to_capture
,
max_decode_seq_len
=
self
.
runner
.
max_seq_len_to_capture
,
query_start_loc
=
None
,
query_start_loc
=
None
,
...
...
vllm/attention/backends/xformers.py
View file @
15702038
...
@@ -118,6 +118,12 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -118,6 +118,12 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Maximum query length in the batch. None for decoding.
# Maximum query length in the batch. None for decoding.
max_query_len
:
Optional
[
int
]
=
None
max_query_len
:
Optional
[
int
]
=
None
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len
:
Optional
[
int
]
=
None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
# is [4, 6], it is [0, 4, 10].
...
...
vllm/config.py
View file @
15702038
...
@@ -1116,6 +1116,7 @@ class SpeculativeConfig:
...
@@ -1116,6 +1116,7 @@ class SpeculativeConfig:
speculative_model_quantization
:
Optional
[
str
],
speculative_model_quantization
:
Optional
[
str
],
speculative_draft_tensor_parallel_size
:
Optional
[
int
],
speculative_draft_tensor_parallel_size
:
Optional
[
int
],
num_speculative_tokens
:
Optional
[
int
],
num_speculative_tokens
:
Optional
[
int
],
speculative_disable_mqa_scorer
:
Optional
[
bool
],
speculative_max_model_len
:
Optional
[
int
],
speculative_max_model_len
:
Optional
[
int
],
enable_chunked_prefill
:
bool
,
enable_chunked_prefill
:
bool
,
use_v2_block_manager
:
bool
,
use_v2_block_manager
:
bool
,
...
@@ -1150,6 +1151,9 @@ class SpeculativeConfig:
...
@@ -1150,6 +1151,9 @@ class SpeculativeConfig:
num_speculative_tokens (Optional[int]): The number of speculative
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided. Will default to the number in the draft
tokens, if provided. Will default to the number in the draft
model config if present, otherwise is required.
model config if present, otherwise is required.
speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
scorer for the speculative model and fall back to batch
expansion for scoring.
speculative_max_model_len (Optional[int]): The maximum model len of
speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip
the speculative model. Used when testing the ability to skip
speculation for some sequences.
speculation for some sequences.
...
@@ -1304,6 +1308,7 @@ class SpeculativeConfig:
...
@@ -1304,6 +1308,7 @@ class SpeculativeConfig:
draft_model_config
,
draft_model_config
,
draft_parallel_config
,
draft_parallel_config
,
num_speculative_tokens
,
num_speculative_tokens
,
speculative_disable_mqa_scorer
,
speculative_disable_by_batch_size
,
speculative_disable_by_batch_size
,
ngram_prompt_lookup_max
,
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
,
ngram_prompt_lookup_min
,
...
@@ -1400,6 +1405,7 @@ class SpeculativeConfig:
...
@@ -1400,6 +1405,7 @@ class SpeculativeConfig:
draft_model_config
:
ModelConfig
,
draft_model_config
:
ModelConfig
,
draft_parallel_config
:
ParallelConfig
,
draft_parallel_config
:
ParallelConfig
,
num_speculative_tokens
:
int
,
num_speculative_tokens
:
int
,
speculative_disable_mqa_scorer
:
Optional
[
bool
],
speculative_disable_by_batch_size
:
Optional
[
int
],
speculative_disable_by_batch_size
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_max
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
ngram_prompt_lookup_min
:
Optional
[
int
],
...
@@ -1446,6 +1452,7 @@ class SpeculativeConfig:
...
@@ -1446,6 +1452,7 @@ class SpeculativeConfig:
self
.
draft_model_config
=
draft_model_config
self
.
draft_model_config
=
draft_model_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
draft_parallel_config
=
draft_parallel_config
self
.
num_speculative_tokens
=
num_speculative_tokens
self
.
num_speculative_tokens
=
num_speculative_tokens
self
.
speculative_disable_mqa_scorer
=
speculative_disable_mqa_scorer
self
.
speculative_disable_by_batch_size
=
\
self
.
speculative_disable_by_batch_size
=
\
speculative_disable_by_batch_size
speculative_disable_by_batch_size
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
or
0
self
.
ngram_prompt_lookup_max
=
ngram_prompt_lookup_max
or
0
...
...
vllm/engine/arg_utils.py
View file @
15702038
...
@@ -162,6 +162,7 @@ class EngineArgs:
...
@@ -162,6 +162,7 @@ class EngineArgs:
speculative_model_quantization
:
Optional
[
str
]
=
None
speculative_model_quantization
:
Optional
[
str
]
=
None
speculative_draft_tensor_parallel_size
:
Optional
[
int
]
=
None
speculative_draft_tensor_parallel_size
:
Optional
[
int
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
speculative_disable_mqa_scorer
:
Optional
[
bool
]
=
False
speculative_max_model_len
:
Optional
[
int
]
=
None
speculative_max_model_len
:
Optional
[
int
]
=
None
speculative_disable_by_batch_size
:
Optional
[
int
]
=
None
speculative_disable_by_batch_size
:
Optional
[
int
]
=
None
ngram_prompt_lookup_max
:
Optional
[
int
]
=
None
ngram_prompt_lookup_max
:
Optional
[
int
]
=
None
...
@@ -640,6 +641,12 @@ class EngineArgs:
...
@@ -640,6 +641,12 @@ class EngineArgs:
default
=
EngineArgs
.
num_speculative_tokens
,
default
=
EngineArgs
.
num_speculative_tokens
,
help
=
'The number of speculative tokens to sample from '
help
=
'The number of speculative tokens to sample from '
'the draft model in speculative decoding.'
)
'the draft model in speculative decoding.'
)
parser
.
add_argument
(
'--speculative-disable-mqa-scorer'
,
action
=
'store_true'
,
help
=
'If set to True, the MQA scorer will be disabled in speculative '
' and fall back to batch expansion'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--speculative-draft-tensor-parallel-size'
,
'--speculative-draft-tensor-parallel-size'
,
'-spec-draft-tp'
,
'-spec-draft-tp'
,
...
@@ -970,6 +977,7 @@ class EngineArgs:
...
@@ -970,6 +977,7 @@ class EngineArgs:
speculative_draft_tensor_parallel_size
=
\
speculative_draft_tensor_parallel_size
=
\
self
.
speculative_draft_tensor_parallel_size
,
self
.
speculative_draft_tensor_parallel_size
,
num_speculative_tokens
=
self
.
num_speculative_tokens
,
num_speculative_tokens
=
self
.
num_speculative_tokens
,
speculative_disable_mqa_scorer
=
self
.
speculative_disable_mqa_scorer
,
speculative_disable_by_batch_size
=
self
.
speculative_disable_by_batch_size
=
self
.
speculative_disable_by_batch_size
,
speculative_disable_by_batch_size
,
speculative_max_model_len
=
self
.
speculative_max_model_len
,
speculative_max_model_len
=
self
.
speculative_max_model_len
,
...
...
vllm/engine/llm_engine.py
View file @
15702038
...
@@ -1110,6 +1110,8 @@ class LLMEngine:
...
@@ -1110,6 +1110,8 @@ class LLMEngine:
update_prefill_num_computed_tokens
(
seq_group
,
seq_group_meta
,
update_prefill_num_computed_tokens
(
seq_group
,
seq_group_meta
,
len
(
output
),
len
(
output
),
is_first_step_output
)
is_first_step_output
)
elif
not
is_async
:
seq_group
.
update_num_computed_tokens
(
1
)
if
outputs
:
if
outputs
:
for
o
in
outputs
:
for
o
in
outputs
:
...
@@ -1133,8 +1135,16 @@ class LLMEngine:
...
@@ -1133,8 +1135,16 @@ class LLMEngine:
else
:
else
:
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
)
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
)
if
seq_group_meta
.
do_sample
:
if
seq_group_meta
.
do_sample
:
self
.
output_processor
.
process_outputs
(
output_token_num
=
self
.
output_processor
.
process_outputs
(
seq_group
,
output
,
is_async
)
seq_group
,
output
,
is_async
)
if
self
.
speculative_config
:
# We -1 here because we always
# (w/o speculative decoding) add the number of
# computed tokens by one in the decoding phase.
# Therefore, we remove that one token that
# is already added.
seq_group
.
update_num_computed_tokens
(
output_token_num
-
1
)
if
seq_group
.
is_finished
():
if
seq_group
.
is_finished
():
finished_now
.
append
(
i
)
finished_now
.
append
(
i
)
...
@@ -1251,11 +1261,12 @@ class LLMEngine:
...
@@ -1251,11 +1261,12 @@ class LLMEngine:
# decodes after the very first step. Therefore,
# decodes after the very first step. Therefore,
# we skip the update to the num_computed_tokens
# we skip the update to the num_computed_tokens
# here.
# here.
pass
seq_group
.
update_num_computed_tokens
(
1
)
else
:
else
:
seq_group
.
update_num_computed_tokens
(
seq_group
.
update_num_computed_tokens
(
seq_group_metadata
.
token_chunk_size
)
seq_group_metadata
.
token_chunk_size
)
else
:
seq_group
.
update_num_computed_tokens
(
1
)
if
seq_group_metadata
.
do_sample
:
if
seq_group_metadata
.
do_sample
:
assert
len
(
sequence_group_outputs
.
samples
)
==
1
,
(
assert
len
(
sequence_group_outputs
.
samples
)
==
1
,
(
"Async output processor expects a single sample"
"Async output processor expects a single sample"
...
@@ -1266,7 +1277,6 @@ class LLMEngine:
...
@@ -1266,7 +1277,6 @@ class LLMEngine:
assert
len
(
seq_group
.
seqs
)
==
1
assert
len
(
seq_group
.
seqs
)
==
1
seq
=
seq_group
.
seqs
[
0
]
seq
=
seq_group
.
seqs
[
0
]
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
seq_group
.
update_num_computed_tokens
(
1
)
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
"""Performs one decoding iteration and returns newly generated results.
"""Performs one decoding iteration and returns newly generated results.
...
...
vllm/engine/output_processor/interfaces.py
View file @
15702038
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Callable
,
List
from
typing
import
Callable
,
List
,
Optional
from
vllm.config
import
SchedulerConfig
from
vllm.config
import
SchedulerConfig
from
vllm.core.scheduler
import
Scheduler
from
vllm.core.scheduler
import
Scheduler
...
@@ -58,10 +58,14 @@ class SequenceGroupOutputProcessor(ABC):
...
@@ -58,10 +58,14 @@ class SequenceGroupOutputProcessor(ABC):
@
abstractmethod
@
abstractmethod
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
],
outputs
:
List
[
SequenceGroupOutput
],
is_async
:
bool
)
->
None
:
is_async
:
bool
)
->
Optional
[
int
]
:
"""Process new token ids for the sequence group. Handles logic such as
"""Process new token ids for the sequence group. Handles logic such as
detokenization, stop checking, and freeing/forking sequences in the
detokenization, stop checking, and freeing/forking sequences in the
scheduler.
scheduler.
Return the number of new tokens generated in the sequence group.
The returned value is optional because it is only used for
speculative decoding mqa scorer.
"""
"""
pass
pass
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment