Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
312f7612
Unverified
Commit
312f7612
authored
Aug 20, 2024
by
Abhinav Goyal
Committed by
GitHub
Aug 19, 2024
Browse files
[Speculative Decoding] Fixing hidden states handling in batch expansion (#7508)
parent
e54ebc2f
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
139 additions
and
41 deletions
+139
-41
tests/spec_decode/e2e/conftest.py
tests/spec_decode/e2e/conftest.py
+16
-9
tests/spec_decode/e2e/test_mlp_correctness.py
tests/spec_decode/e2e/test_mlp_correctness.py
+42
-0
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+60
-26
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+2
-3
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+1
-1
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+18
-2
No files found.
tests/spec_decode/e2e/conftest.py
View file @
312f7612
...
@@ -288,15 +288,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
...
@@ -288,15 +288,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
ensure_all_accepted
=
ensure_all_accepted
)
ensure_all_accepted
=
ensure_all_accepted
)
def
run_equality_correctness_test
(
baseline_llm_generator
,
def
run_equality_correctness_test
(
test_llm_generator
,
baseline_llm_generator
,
batch_size
,
test_llm_generator
,
max_output_len
,
batch_size
,
force_output_len
:
bool
,
max_output_len
,
temperature
:
float
,
force_output_len
:
bool
,
seeded
:
bool
,
temperature
:
float
,
print_tokens
:
bool
=
False
,
seeded
:
bool
,
ensure_all_accepted
:
bool
=
False
):
print_tokens
:
bool
=
False
,
ensure_all_accepted
:
bool
=
False
,
expected_acceptance_rate
:
Optional
[
float
]
=
None
):
"""Helper method that compares the outputs of both the baseline LLM and
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero (or when temperature is > 0 and seeded).
the same when temperature is zero (or when temperature is > 0 and seeded).
...
@@ -357,5 +359,10 @@ def run_equality_correctness_test(baseline_llm_generator,
...
@@ -357,5 +359,10 @@ def run_equality_correctness_test(baseline_llm_generator,
print
(
f
'
{
i
=
}
{
spec_token_ids
=
}
'
)
print
(
f
'
{
i
=
}
{
spec_token_ids
=
}
'
)
assert
baseline_token_ids
==
spec_token_ids
assert
baseline_token_ids
==
spec_token_ids
print
(
f
'
{
acceptance_rate
=
}
'
)
if
ensure_all_accepted
:
if
ensure_all_accepted
:
assert
acceptance_rate
==
1.0
assert
acceptance_rate
==
1.0
if
expected_acceptance_rate
is
not
None
:
assert
acceptance_rate
>=
expected_acceptance_rate
-
1e-2
tests/spec_decode/e2e/test_mlp_correctness.py
View file @
312f7612
...
@@ -82,6 +82,48 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
...
@@ -82,6 +82,48 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len
=
True
)
force_output_len
=
True
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Skip cuda graph recording for fast test.
"enforce_eager"
:
True
,
# Required for spec decode.
"use_v2_block_manager"
:
True
,
# Print spec metrics.
"disable_log_stats"
:
False
,
# Precision
"dtype"
:
PRECISION
,
# Main model
"model"
:
MAIN_MODEL
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[
{
"speculative_model"
:
SPEC_MODEL
,
},
])
@
pytest
.
mark
.
parametrize
(
"output_len"
,
[
2048
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
32
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_mlp_e2e_acceptance_rate
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
:
int
,
output_len
:
int
):
"""Verify acceptance rate with different batch size and large output
length."""
run_equality_correctness_test
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
max_output_len
=
output_len
,
temperature
=
0.0
,
seeded
=
True
,
force_output_len
=
True
,
expected_acceptance_rate
=
0.48
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
"common_llm_kwargs"
,
[{
[{
...
...
vllm/spec_decode/batch_expansion.py
View file @
312f7612
from
array
import
array
from
array
import
array
from
itertools
import
chain
,
count
from
itertools
import
chain
,
count
from
typing
import
Iterator
,
List
,
Tuple
from
typing
import
Iterator
,
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -88,21 +88,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -88,21 +88,22 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
target_sampler_output
=
target_sampler_output
[
0
]
all_tokens
,
all_probs
,
spec_logprobs
=
self
.
_contract_batch
(
(
all_tokens
,
all_probs
,
spec_logprobs
,
contracted_bs
=
len
(
execute_model_req
.
seq_group_metadata_list
),
all_hidden_states
)
=
self
.
_contract_batch
(
target_sampler_output
=
target_sampler_output
,
contracted_bs
=
len
(
execute_model_req
.
seq_group_metadata_list
),
proposals
=
proposals
,
target_sampler_output
=
target_sampler_output
,
num_scoring_tokens
=
num_scoring_tokens
,
proposals
=
proposals
,
non_spec_indices
=
non_spec_indices
,
num_scoring_tokens
=
num_scoring_tokens
,
spec_indices
=
spec_indices
,
non_spec_indices
=
non_spec_indices
,
k
=
execute_model_req
.
num_lookahead_slots
,
spec_indices
=
spec_indices
,
)
k
=
execute_model_req
.
num_lookahead_slots
,
)
return
SpeculativeScores
(
return
SpeculativeScores
(
probs
=
all_probs
,
probs
=
all_probs
,
token_ids
=
all_tokens
,
token_ids
=
all_tokens
,
logprobs
=
spec_logprobs
,
logprobs
=
spec_logprobs
,
hidden_states
=
target_sampler_output
.
hidden_states
,
hidden_states
=
all_
hidden_states
,
)
)
def
_expand_batch
(
def
_expand_batch
(
...
@@ -145,10 +146,11 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -145,10 +146,11 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
num_scoring_tokens
)
num_scoring_tokens
)
def
_contract_batch
(
def
_contract_batch
(
self
,
contracted_bs
:
int
,
target_sampler_output
:
SamplerOutput
,
self
,
contracted_bs
:
int
,
target_sampler_output
:
SamplerOutput
,
proposals
:
SpeculativeProposals
,
num_scoring_tokens
:
int
,
proposals
:
SpeculativeProposals
,
num_scoring_tokens
:
int
,
non_spec_indices
:
List
[
int
],
spec_indices
:
List
[
int
],
non_spec_indices
:
List
[
int
],
spec_indices
:
List
[
int
],
k
:
int
k
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""Contract the expanded batch back into its original size.
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
This maps the scores of speculative tokens back to their original
sequences.
sequences.
...
@@ -156,9 +158,10 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -156,9 +158,10 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
contracted_bs is the original batch size, and the batch size that the
contracted_bs is the original batch size, and the batch size that the
target_sampler_output will be contracted to.
target_sampler_output will be contracted to.
"""
"""
(
target_token_ids
,
target_probs
,
target_logprobs
,
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
,
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
)
=
self
.
_split_scoring_output
(
non_spec_target_logprobs
,
non_spec_target_hidden_states
)
=
self
.
_split_scoring_output
(
target_sampler_output
,
num_scoring_tokens
)
target_sampler_output
,
num_scoring_tokens
)
# Map distinct sequences used to score each token
# Map distinct sequences used to score each token
...
@@ -176,23 +179,40 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -176,23 +179,40 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
self
.
_vocab_size
)
self
.
_vocab_size
)
target_logprobs
=
target_logprobs
.
reshape
(
target_probs
.
shape
)
target_logprobs
=
target_logprobs
.
reshape
(
target_probs
.
shape
)
if
target_hidden_states
is
not
None
:
target_hidden_states
=
target_hidden_states
.
reshape
(
spec_expanded_bs
,
k
+
1
,
target_hidden_states
.
shape
[
-
1
])
all_tokens
=
target_token_ids
.
new_full
(
size
=
(
contracted_bs
,
k
+
1
),
all_tokens
=
target_token_ids
.
new_full
(
size
=
(
contracted_bs
,
k
+
1
),
fill_value
=-
1
)
fill_value
=-
1
)
all_probs
=
target_probs
.
new_zeros
(
*
all_tokens
.
shape
,
self
.
_vocab_size
)
all_probs
=
target_probs
.
new_zeros
(
*
all_tokens
.
shape
,
self
.
_vocab_size
)
all_logprobs
=
target_logprobs
.
new_full
(
size
=
all_probs
.
shape
,
all_logprobs
=
target_logprobs
.
new_full
(
size
=
all_probs
.
shape
,
fill_value
=-
float
(
"inf"
))
fill_value
=-
float
(
"inf"
))
if
target_sampler_output
.
hidden_states
is
not
None
:
all_hidden_states
=
target_hidden_states
.
new_zeros
(
size
=
(
contracted_bs
,
k
+
1
,
target_hidden_states
.
shape
[
-
1
]))
else
:
all_hidden_states
=
None
if
non_spec_indices
:
if
non_spec_indices
:
all_tokens
[
non_spec_indices
,
:
1
]
=
non_spec_target_token_ids
all_tokens
[
non_spec_indices
,
:
1
]
=
non_spec_target_token_ids
all_probs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_probs
all_probs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_probs
all_logprobs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_logprobs
all_logprobs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_logprobs
if
all_hidden_states
is
not
None
:
all_hidden_states
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_hidden_states
if
spec_indices
:
if
spec_indices
:
all_tokens
[
spec_indices
]
=
target_token_ids
all_tokens
[
spec_indices
]
=
target_token_ids
all_probs
[
spec_indices
]
=
target_probs
all_probs
[
spec_indices
]
=
target_probs
all_logprobs
[
spec_indices
]
=
target_logprobs
all_logprobs
[
spec_indices
]
=
target_logprobs
return
all_tokens
,
all_probs
,
all_logprobs
if
all_hidden_states
is
not
None
:
all_hidden_states
[
spec_indices
]
=
target_hidden_states
return
all_tokens
,
all_probs
,
all_logprobs
,
all_hidden_states
def
_create_scoring_model_input
(
def
_create_scoring_model_input
(
self
,
self
,
...
@@ -327,8 +347,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -327,8 +347,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
def
_split_scoring_output
(
def
_split_scoring_output
(
self
,
sampler_output
:
SamplerOutput
,
num_scoring_tokens
:
int
self
,
sampler_output
:
SamplerOutput
,
num_scoring_tokens
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""Split the target model output into speculative and non-speculative
"""Split the target model output into speculative and non-speculative
output.
output.
"""
"""
...
@@ -353,24 +374,37 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -353,24 +374,37 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
non_spec_logprobs
,
non_spec_logprobs
,
)
=
sampler_output
.
logprobs
.
split
(
split_sizes
)
)
=
sampler_output
.
logprobs
.
split
(
split_sizes
)
if
sampler_output
.
hidden_states
is
not
None
:
(
spec_hidden_states
,
non_spec_hidden_states
,
)
=
sampler_output
.
hidden_states
.
split
(
split_sizes
)
else
:
spec_hidden_states
,
non_spec_hidden_states
=
None
,
None
# Convert scores to tensors.
# Convert scores to tensors.
sampler_output
.
sampled_token_probs
=
spec_probs
sampler_output
.
sampled_token_probs
=
spec_probs
sampler_output
.
sampled_token_ids
=
spec_sampled_tokens
sampler_output
.
sampled_token_ids
=
spec_sampled_tokens
sampler_output
.
logprobs
=
spec_logprobs
sampler_output
.
logprobs
=
spec_logprobs
(
target_token_ids
,
target_probs
,
sampler_output
.
hidden_states
=
spec_hidden_states
target_logprobs
)
=
sampler_output_to_torch
([
sampler_output
],
True
)
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
)
=
sampler_output_to_torch
([
sampler_output
],
True
)
# Convert non-speculative output tokens to tensors.
# Convert non-speculative output tokens to tensors.
sampler_output
.
sampled_token_probs
=
non_spec_probs
sampler_output
.
sampled_token_probs
=
non_spec_probs
sampler_output
.
sampled_token_ids
=
non_spec_sampled_tokens
sampler_output
.
sampled_token_ids
=
non_spec_sampled_tokens
sampler_output
.
logprobs
=
non_spec_logprobs
sampler_output
.
logprobs
=
non_spec_logprobs
sampler_output
.
hidden_states
=
non_spec_hidden_states
(
non_spec_target_token_ids
,
non_spec_target_probs
,
(
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
)
=
sampler_output_to_torch
([
sampler_output
],
non_spec_target_logprobs
,
True
)
non_spec_target_hidden_states
)
=
sampler_output_to_torch
(
[
sampler_output
],
True
)
return
(
target_token_ids
,
target_probs
,
target_logprobs
,
return
(
target_token_ids
,
target_probs
,
target_logprobs
,
non_spec_target_token_ids
,
non_spec_target_probs
,
target_hidden_states
,
non_spec_target_token_ids
,
non_spec_target_logprobs
)
non_spec_target_probs
,
non_spec_target_logprobs
,
non_spec_target_hidden_states
)
def
_create_target_seq_id_iterator
(
def
_create_target_seq_id_iterator
(
self
,
seq_ids
:
List
[
SeqId
])
->
Iterator
[
TargetSeqId
]:
self
,
seq_ids
:
List
[
SeqId
])
->
Iterator
[
TargetSeqId
]:
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
312f7612
...
@@ -646,9 +646,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -646,9 +646,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
hidden_states
=
proposal_scores
.
hidden_states
hidden_states
=
proposal_scores
.
hidden_states
if
hidden_states
is
not
None
:
if
hidden_states
is
not
None
:
# Contract hidden states based on accepted tokens
# Contract hidden states based on accepted tokens
hs_size
=
hidden_states
.
shape
[
1
]
hs_size
=
hidden_states
.
shape
[
-
1
]
hidden_states
=
hidden_states
.
reshape
(
-
1
,
max_proposal_len
+
1
,
hs_size
)
accepted_index
=
accepted_token_ids
+
1
# Convert -1 to 0
accepted_index
=
accepted_token_ids
+
1
# Convert -1 to 0
accepted_index
=
accepted_index
.
count_nonzero
(
dim
=
1
).
add_
(
-
1
)
accepted_index
=
accepted_index
.
count_nonzero
(
dim
=
1
).
add_
(
-
1
)
index
=
accepted_index
[:,
None
,
None
].
expand
(
-
1
,
1
,
hs_size
)
index
=
accepted_index
[:,
None
,
None
].
expand
(
-
1
,
1
,
hs_size
)
...
...
vllm/spec_decode/top1_proposer.py
View file @
312f7612
...
@@ -242,7 +242,7 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -242,7 +242,7 @@ class Top1Proposer(SpeculativeProposer):
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
sampler_output
=
maybe_sampler_output
sampler_output
=
maybe_sampler_output
proposal_tokens
,
proposal_probs
,
_
=
sampler_output_to_torch
(
proposal_tokens
,
proposal_probs
,
*
_
=
sampler_output_to_torch
(
sampler_output
,
sampler_transposed
)
sampler_output
,
sampler_transposed
)
# Now, reformat the output GPU tensors such that each sequence has
# Now, reformat the output GPU tensors such that each sequence has
...
...
vllm/spec_decode/util.py
View file @
312f7612
...
@@ -123,7 +123,7 @@ def split_batch_by_proposal_len(
...
@@ -123,7 +123,7 @@ def split_batch_by_proposal_len(
def
sampler_output_to_torch
(
def
sampler_output_to_torch
(
sampler_output_list
:
List
[
SamplerOutput
],
sampler_transposed
:
bool
sampler_output_list
:
List
[
SamplerOutput
],
sampler_transposed
:
bool
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]:
"""Utility function which converts a list of SamplerOutput to tensors.
"""Utility function which converts a list of SamplerOutput to tensors.
sampler_transposed here is used as the indicator for whether
sampler_transposed here is used as the indicator for whether
...
@@ -169,7 +169,23 @@ def sampler_output_to_torch(
...
@@ -169,7 +169,23 @@ def sampler_output_to_torch(
if
sampler_transposed
:
if
sampler_transposed
:
sampled_token_ids
=
sampled_token_ids
.
transpose
(
0
,
1
)
sampled_token_ids
=
sampled_token_ids
.
transpose
(
0
,
1
)
return
sampled_token_ids
,
sampled_token_probs
,
sampled_token_logprobs
if
sampler_output_list
[
0
].
hidden_states
is
not
None
:
# shape: [batch_size, num_sampler_output, hidden_dim]
sampled_hidden_states
=
torch
.
stack
(
[
sampler_output
.
hidden_states
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_hidden_states
=
sampled_hidden_states
.
transpose
(
0
,
1
)
else
:
sampled_hidden_states
=
None
return
(
sampled_token_ids
,
sampled_token_probs
,
sampled_token_logprobs
,
sampled_hidden_states
)
def
maybe_mock_device_tensors
(
sampler_output
:
SamplerOutput
,
batch_size
:
int
,
def
maybe_mock_device_tensors
(
sampler_output
:
SamplerOutput
,
batch_size
:
int
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment