Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1856aff4
Unverified
Commit
1856aff4
authored
Aug 25, 2024
by
Nick Hill
Committed by
GitHub
Aug 25, 2024
Browse files
[Spec Decoding] Streamline batch expansion tensor manipulation (#7851)
parent
70c094ad
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
118 additions
and
125 deletions
+118
-125
tests/spec_decode/test_utils.py
tests/spec_decode/test_utils.py
+13
-18
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+79
-64
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+9
-16
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+1
-1
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+16
-26
No files found.
tests/spec_decode/test_utils.py
View file @
1856aff4
...
...
@@ -55,10 +55,9 @@ def fake_sequence_group_metadata():
def
test_filter_zero_length_proposals
(
fake_sequence_group_metadata
):
proposal_lens
=
[
0
,
1
,
0
]
filtered_groups
,
indices
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
,
select_proposal_len_zero
=
True
)
_
,
(
filtered_groups
,
indices
)
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
)
expected_groups
=
[
fake_sequence_group_metadata
[
0
],
fake_sequence_group_metadata
[
2
]
...
...
@@ -71,10 +70,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata):
def
test_filter_non_zero_length_proposals
(
fake_sequence_group_metadata
):
proposal_lens
=
[
0
,
1
,
2
]
filtered_groups
,
indices
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
,
select_proposal_len_zero
=
False
)
(
filtered_groups
,
indices
),
_
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
)
expected_groups
=
[
fake_sequence_group_metadata
[
1
],
fake_sequence_group_metadata
[
2
]
...
...
@@ -86,8 +84,7 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
def
test_empty_inputs
():
filtered_groups
,
indices
=
split_batch_by_proposal_len
(
[],
[],
select_proposal_len_zero
=
True
)
_
,
(
filtered_groups
,
indices
)
=
split_batch_by_proposal_len
([],
[])
assert
filtered_groups
==
[]
assert
indices
==
[]
...
...
@@ -95,10 +92,9 @@ def test_empty_inputs():
def
test_all_zero_with_non_zero_filter
(
fake_sequence_group_metadata
):
proposal_lens
=
[
0
,
0
,
0
]
filtered_groups
,
indices
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
,
select_proposal_len_zero
=
False
)
(
filtered_groups
,
indices
),
_
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
)
assert
filtered_groups
==
[]
assert
indices
==
[]
...
...
@@ -106,10 +102,9 @@ def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
def
test_all_non_zero_with_zero_filter
(
fake_sequence_group_metadata
):
proposal_lens
=
[
1
,
1
,
1
]
filtered_groups
,
indices
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
,
select_proposal_len_zero
=
True
)
_
,
(
filtered_groups
,
indices
)
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
)
assert
filtered_groups
==
[]
assert
indices
==
[]
...
...
vllm/spec_decode/batch_expansion.py
View file @
1856aff4
...
...
@@ -10,8 +10,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
split_batch_by_proposal_len
)
from
vllm.spec_decode.util
import
nvtx_range
,
split_batch_by_proposal_len
from
vllm.worker.worker_base
import
WorkerBase
SeqId
=
int
...
...
@@ -88,8 +87,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
(
all_tokens
,
all_probs
,
spec_logprobs
,
all_hidden_states
)
=
self
.
_contract_batch
(
if
not
non_spec_indices
:
# All sequence groups in batch have spec decoding enabled
contracted
=
self
.
_contract_batch_all_spec
(
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
)
else
:
# Batch has a mix of spec decode enabled and disabled seq groups
contracted
=
self
.
_contract_batch
(
contracted_bs
=
len
(
execute_model_req
.
seq_group_metadata_list
),
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
...
...
@@ -99,6 +105,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
k
=
execute_model_req
.
num_lookahead_slots
,
)
all_tokens
,
all_probs
,
spec_logprobs
,
all_hidden_states
=
contracted
return
SpeculativeScores
(
probs
=
all_probs
,
token_ids
=
all_tokens
,
...
...
@@ -121,14 +128,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
spec_seqs
,
spec_indices
=
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
,
select_proposal_len_zero
=
False
)
non_spec_seqs
,
non_spec_indices
=
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
,
select_proposal_len_zero
=
True
)
(
spec_seqs
,
spec_indices
),
(
non_spec_seqs
,
non_spec_indices
)
=
\
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
)
target_seq_group_metadata_list
=
self
.
_create_scoring_model_input
(
seq_group_metadata_list
=
spec_seqs
,
...
...
@@ -171,7 +173,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences.
non_spec_expanded_bs
,
_
=
non_spec_target_token_ids
.
shape
non_spec_expanded_bs
=
len
(
non_spec_target_token_ids
)
spec_expanded_bs
=
expanded_batch_size
-
non_spec_expanded_bs
target_token_ids
=
target_token_ids
.
reshape
(
spec_expanded_bs
,
k
+
1
)
...
...
@@ -181,7 +183,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
if
target_hidden_states
is
not
None
:
target_hidden_states
=
target_hidden_states
.
reshape
(
spec_expanded_bs
,
k
+
1
,
target_hidden_states
.
shape
[
-
1
])
*
target_token_ids
.
shape
,
target_hidden_states
.
shape
[
-
1
])
all_tokens
=
target_token_ids
.
new_full
(
size
=
(
contracted_bs
,
k
+
1
),
fill_value
=-
1
)
...
...
@@ -196,24 +198,58 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
all_hidden_states
=
None
if
non_spec_indices
:
all_tokens
[
non_spec_indices
,
:
1
]
=
non_spec_target_token_ids
all_probs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_probs
all_logprobs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_logprobs
all_tokens
[
non_spec_indices
,
:
1
]
=
\
non_spec_target_token_ids
.
unsqueeze
(
1
)
all_probs
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_target_probs
.
unsqueeze
(
1
)
all_logprobs
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_target_logprobs
.
unsqueeze
(
1
)
if
all_hidden_states
is
not
None
:
all_hidden_states
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_hidden_states
assert
non_spec_target_hidden_states
is
not
None
all_hidden_states
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_target_hidden_states
.
unsqueeze
(
1
)
if
spec_indices
:
all_tokens
[
spec_indices
]
=
target_token_ids
all_probs
[
spec_indices
]
=
target_probs
all_logprobs
[
spec_indices
]
=
target_logprobs
if
all_hidden_states
is
not
None
:
all_hidden_states
[
spec_indices
]
=
target_hidden_states
return
all_tokens
,
all_probs
,
all_logprobs
,
all_hidden_states
def
_contract_batch_all_spec
(
self
,
target_sampler_output
:
SamplerOutput
,
proposals
:
SpeculativeProposals
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
It assumes all sequences in the batch were previously expanded.
"""
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
contracted_bs
,
k
=
proposals
.
proposal_token_ids
.
shape
# Reshape tensors to original batch size
target_token_ids
=
target_sampler_output
.
sampled_token_ids
.
reshape
(
contracted_bs
,
k
+
1
)
target_probs
=
target_sampler_output
.
sampled_token_probs
.
reshape
(
*
target_token_ids
.
shape
,
self
.
_vocab_size
)
target_logprobs
=
target_sampler_output
.
logprobs
.
reshape
(
target_probs
.
shape
)
target_hidden_states
=
target_sampler_output
.
hidden_states
if
target_hidden_states
is
not
None
:
target_hidden_states
=
target_hidden_states
.
reshape
(
*
target_token_ids
.
shape
,
target_hidden_states
.
shape
[
-
1
])
return
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
)
def
_create_scoring_model_input
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
...
@@ -345,8 +381,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_chunk_size
=
1
,
)
@
staticmethod
def
_split_scoring_output
(
self
,
sampler_output
:
SamplerOutput
,
num_scoring_tokens
:
int
sampler_output
:
SamplerOutput
,
num_scoring_tokens
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
...
...
@@ -361,10 +398,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
#
# First samples are from speculative scoring, latter samples are non-
# speculative samples.
split_sizes
=
[
num_scoring_tokens
,
sampler_output
.
sampled_token_ids
.
numel
()
-
num_scoring_tokens
]
split_sizes
=
(
num_scoring_tokens
,
sampler_output
.
sampled_token_ids
.
numel
()
-
num_scoring_tokens
)
(
spec_probs
,
non_spec_probs
)
=
sampler_output
.
sampled_token_probs
.
split
(
split_sizes
)
(
spec_sampled_tokens
,
non_spec_sampled_tokens
...
...
@@ -382,32 +418,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else
:
spec_hidden_states
,
non_spec_hidden_states
=
None
,
None
# Convert scores to tensors.
sampler_output
.
sampled_token_probs
=
spec_probs
sampler_output
.
sampled_token_ids
=
spec_sampled_tokens
sampler_output
.
logprobs
=
spec_logprobs
sampler_output
.
hidden_states
=
spec_hidden_states
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
)
=
sampler_output_to_torch
([
sampler_output
],
True
)
# Convert non-speculative output tokens to tensors.
sampler_output
.
sampled_token_probs
=
non_spec_probs
sampler_output
.
sampled_token_ids
=
non_spec_sampled_tokens
sampler_output
.
logprobs
=
non_spec_logprobs
sampler_output
.
hidden_states
=
non_spec_hidden_states
(
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
,
non_spec_target_hidden_states
)
=
sampler_output_to_torch
(
[
sampler_output
],
True
)
return
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
,
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
,
non_spec_target_hidden_states
)
return
(
spec_sampled_tokens
,
spec_probs
,
spec_logprobs
,
spec_hidden_states
,
non_spec_sampled_tokens
,
non_spec_probs
,
non_spec_logprobs
,
non_spec_hidden_states
)
@
staticmethod
def
_create_target_seq_id_iterator
(
self
,
seq_ids
:
List
[
SeqId
])
->
Iterator
[
TargetSeqId
]:
seq_ids
:
List
[
SeqId
])
->
Iterator
[
TargetSeqId
]:
"""Create an iterator for creating target sequence ids.
Target sequence ids are distinct from sequence ids because we create a
distinct target sequence id for each proposal token to be scored.
...
...
@@ -417,8 +434,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
"""
return
count
(
start
=
max
(
seq_ids
)
+
1
)
@
staticmethod
def
_get_token_ids_to_score
(
self
,
full_spec_token_ids
:
List
[
TokenId
]
# shape: [k]
)
->
List
[
List
[
TokenId
]]:
"""Given an int tensor of proposal token ids, return a list of
...
...
@@ -439,8 +456,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
empty_token_ids
:
List
[
TokenId
]
=
[]
token_ids_to_score
=
[
empty_token_ids
]
token_ids_to_score
.
extend
([
full_spec_token_ids
[:
i
+
1
]
for
i
in
range
(
len
(
full_spec_token_ids
))
])
token_ids_to_score
.
extend
(
full_spec_token_ids
[:
i
+
1
]
for
i
in
range
(
len
(
full_spec_token_ids
)))
return
token_ids_to_score
vllm/spec_decode/spec_decode_worker.py
View file @
1856aff4
...
...
@@ -365,12 +365,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold.
# 3. No request: There are no requests in the batch.
# 3. No request: There are no requests in the batch, or
# none of the requests in the batch have spec decoding enabled.
# In any of these cases, the proposer and scorer workers
# are called normally.
no_spec
=
num_lookahead_slots
==
0
or
len
(
execute_model_req
.
seq_group_metadata_list
)
==
0
or
disable_all_speculation
no_spec
=
num_lookahead_slots
==
0
or
disable_all_speculation
or
all
(
sgm
.
num_speculative_tokens
==
0
for
sgm
in
execute_model_req
.
seq_group_metadata_list
)
# Broadcast how many lookahead slots are scheduled for this step, and
# whether all speculation is disabled, to all non-driver workers.
...
...
@@ -415,11 +416,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
,
execute_model_req
:
ExecuteModelRequest
)
->
bool
:
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
disable_all_speculation
=
(
execute_model_req
.
running_queue_size
>=
return
(
execute_model_req
.
running_queue_size
>=
self
.
disable_by_batch_size
)
return
disable_all_speculation
def
_maybe_disable_speculative_tokens
(
self
,
disable_all_speculation
:
bool
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
])
->
None
:
...
...
@@ -621,14 +620,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
_
,
spec_indices
=
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
,
select_proposal_len_zero
=
False
)
_
,
non_spec_indices
=
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
,
select_proposal_len_zero
=
True
)
(
_
,
spec_indices
),
(
_
,
non_spec_indices
)
=
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
)
original_indices
=
spec_indices
+
non_spec_indices
# Get probabilities of target model, excluding bonus token.
...
...
vllm/spec_decode/top1_proposer.py
View file @
1856aff4
...
...
@@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
# If max_proposal_len is defined, then we shall no exceed this
# If max_proposal_len is defined, then we shall no
t
exceed this
# quota for nonzero_proposal
new_k
=
0
if
(
self
.
max_proposal_len
is
None
...
...
vllm/spec_decode/util.py
View file @
1856aff4
import
time
from
contextlib
import
contextmanager
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Sequence
,
Tuple
import
torch
...
...
@@ -98,33 +98,26 @@ def create_sequence_group_output(
def
split_batch_by_proposal_len
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_lens
:
List
[
int
],
select_proposal_len_zero
:
bool
)
->
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]:
proposal_lens
:
List
[
int
],
)
->
Tuple
[
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]],
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]]:
"""Utility function that splits a batch based on whether the proposal len is
zero or not. We should remove this once vLLM supports per-sequence proposal
lens in a batch.
"""
if
select_proposal_len_zero
:
predicate
=
lambda
proposal_len
:
proposal_len
==
0
else
:
predicate
=
lambda
proposal_len
:
proposal_len
!=
0
indices
=
[
i
for
i
,
(
_
,
proposal_len
)
in
enumerate
(
zip
(
seq_group_metadata_list
,
proposal_lens
))
if
predicate
(
proposal_len
)
]
seq_groups
=
[
seq_group
for
seq_group
,
proposal_len
in
zip
(
seq_group_metadata_list
,
proposal_lens
)
if
predicate
(
proposal_len
)
]
return
seq_groups
,
indices
nonzero_lists
:
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]
=
([],
[])
zero_lists
:
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]
=
([],
[])
for
i
,
(
seq_group
,
proposal_len
)
in
enumerate
(
zip
(
seq_group_metadata_list
,
proposal_lens
)):
seq_groups
,
indices
=
nonzero_lists
if
proposal_len
else
zero_lists
seq_groups
.
append
(
seq_group
)
indices
.
append
(
i
)
return
nonzero_lists
,
zero_lists
def
sampler_output_to_torch
(
sampler_output_list
:
List
[
SamplerOutput
],
sampler_transposed
:
bool
sampler_output_list
:
Sequence
[
SamplerOutput
],
sampler_transposed
:
bool
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""Utility function which converts a list of SamplerOutput to tensors.
...
...
@@ -148,18 +141,12 @@ def sampler_output_to_torch(
dim
=
0
,
)
if
sampler_transposed
:
sampled_token_probs
=
sampled_token_probs
.
transpose
(
0
,
1
)
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs
=
torch
.
stack
(
[
sampler_output
.
logprobs
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_token_logprobs
=
sampled_token_logprobs
.
transpose
(
0
,
1
)
# shape: [batch_size, num_sampler_output]
sampled_token_ids
=
torch
.
stack
(
[
...
...
@@ -168,7 +155,10 @@ def sampler_output_to_torch(
],
dim
=
0
,
)
if
sampler_transposed
:
sampled_token_probs
=
sampled_token_probs
.
transpose
(
0
,
1
)
sampled_token_logprobs
=
sampled_token_logprobs
.
transpose
(
0
,
1
)
sampled_token_ids
=
sampled_token_ids
.
transpose
(
0
,
1
)
if
sampler_output_list
[
0
].
hidden_states
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment