Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1856aff4
Unverified
Commit
1856aff4
authored
Aug 25, 2024
by
Nick Hill
Committed by
GitHub
Aug 25, 2024
Browse files
[Spec Decoding] Streamline batch expansion tensor manipulation (#7851)
parent
70c094ad
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
118 additions
and
125 deletions
+118
-125
tests/spec_decode/test_utils.py
tests/spec_decode/test_utils.py
+13
-18
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+79
-64
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+9
-16
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+1
-1
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+16
-26
No files found.
tests/spec_decode/test_utils.py
View file @
1856aff4
...
@@ -55,10 +55,9 @@ def fake_sequence_group_metadata():
...
@@ -55,10 +55,9 @@ def fake_sequence_group_metadata():
def
test_filter_zero_length_proposals
(
fake_sequence_group_metadata
):
def
test_filter_zero_length_proposals
(
fake_sequence_group_metadata
):
proposal_lens
=
[
0
,
1
,
0
]
proposal_lens
=
[
0
,
1
,
0
]
filtered_groups
,
indices
=
split_batch_by_proposal_len
(
_
,
(
filtered_groups
,
fake_sequence_group_metadata
,
indices
)
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
,
proposal_lens
)
select_proposal_len_zero
=
True
)
expected_groups
=
[
expected_groups
=
[
fake_sequence_group_metadata
[
0
],
fake_sequence_group_metadata
[
2
]
fake_sequence_group_metadata
[
0
],
fake_sequence_group_metadata
[
2
]
...
@@ -71,10 +70,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata):
...
@@ -71,10 +70,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata):
def
test_filter_non_zero_length_proposals
(
fake_sequence_group_metadata
):
def
test_filter_non_zero_length_proposals
(
fake_sequence_group_metadata
):
proposal_lens
=
[
0
,
1
,
2
]
proposal_lens
=
[
0
,
1
,
2
]
filtered_groups
,
indices
=
split_batch_by_proposal_len
(
(
filtered_groups
,
fake_sequence_group_metadata
,
indices
),
_
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
,
proposal_lens
)
select_proposal_len_zero
=
False
)
expected_groups
=
[
expected_groups
=
[
fake_sequence_group_metadata
[
1
],
fake_sequence_group_metadata
[
2
]
fake_sequence_group_metadata
[
1
],
fake_sequence_group_metadata
[
2
]
...
@@ -86,8 +84,7 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
...
@@ -86,8 +84,7 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
def
test_empty_inputs
():
def
test_empty_inputs
():
filtered_groups
,
indices
=
split_batch_by_proposal_len
(
_
,
(
filtered_groups
,
indices
)
=
split_batch_by_proposal_len
([],
[])
[],
[],
select_proposal_len_zero
=
True
)
assert
filtered_groups
==
[]
assert
filtered_groups
==
[]
assert
indices
==
[]
assert
indices
==
[]
...
@@ -95,10 +92,9 @@ def test_empty_inputs():
...
@@ -95,10 +92,9 @@ def test_empty_inputs():
def
test_all_zero_with_non_zero_filter
(
fake_sequence_group_metadata
):
def
test_all_zero_with_non_zero_filter
(
fake_sequence_group_metadata
):
proposal_lens
=
[
0
,
0
,
0
]
proposal_lens
=
[
0
,
0
,
0
]
filtered_groups
,
indices
=
split_batch_by_proposal_len
(
(
filtered_groups
,
fake_sequence_group_metadata
,
indices
),
_
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
,
proposal_lens
)
select_proposal_len_zero
=
False
)
assert
filtered_groups
==
[]
assert
filtered_groups
==
[]
assert
indices
==
[]
assert
indices
==
[]
...
@@ -106,10 +102,9 @@ def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
...
@@ -106,10 +102,9 @@ def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
def
test_all_non_zero_with_zero_filter
(
fake_sequence_group_metadata
):
def
test_all_non_zero_with_zero_filter
(
fake_sequence_group_metadata
):
proposal_lens
=
[
1
,
1
,
1
]
proposal_lens
=
[
1
,
1
,
1
]
filtered_groups
,
indices
=
split_batch_by_proposal_len
(
_
,
(
filtered_groups
,
fake_sequence_group_metadata
,
indices
)
=
split_batch_by_proposal_len
(
fake_sequence_group_metadata
,
proposal_lens
,
proposal_lens
)
select_proposal_len_zero
=
True
)
assert
filtered_groups
==
[]
assert
filtered_groups
==
[]
assert
indices
==
[]
assert
indices
==
[]
...
...
vllm/spec_decode/batch_expansion.py
View file @
1856aff4
...
@@ -10,8 +10,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
...
@@ -10,8 +10,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
get_all_seq_ids
)
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
from
vllm.spec_decode.util
import
nvtx_range
,
split_batch_by_proposal_len
split_batch_by_proposal_len
)
from
vllm.worker.worker_base
import
WorkerBase
from
vllm.worker.worker_base
import
WorkerBase
SeqId
=
int
SeqId
=
int
...
@@ -88,17 +87,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -88,17 +87,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
target_sampler_output
=
target_sampler_output
[
0
]
(
all_tokens
,
all_probs
,
spec_logprobs
,
if
not
non_spec_indices
:
all_hidden_states
)
=
self
.
_contract_batch
(
# All sequence groups in batch have spec decoding enabled
contracted_bs
=
len
(
execute_model_req
.
seq_group_metadata_list
),
contracted
=
self
.
_contract_batch_all_spec
(
target_sampler_output
=
target_sampler_output
,
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
proposals
=
proposals
,
num_scoring_tokens
=
num_scoring_tokens
,
)
non_spec_indices
=
non_spec_indices
,
else
:
spec_indices
=
spec_indices
,
# Batch has a mix of spec decode enabled and disabled seq groups
k
=
execute_model_req
.
num_lookahead_slots
,
contracted
=
self
.
_contract_batch
(
)
contracted_bs
=
len
(
execute_model_req
.
seq_group_metadata_list
),
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
num_scoring_tokens
=
num_scoring_tokens
,
non_spec_indices
=
non_spec_indices
,
spec_indices
=
spec_indices
,
k
=
execute_model_req
.
num_lookahead_slots
,
)
all_tokens
,
all_probs
,
spec_logprobs
,
all_hidden_states
=
contracted
return
SpeculativeScores
(
return
SpeculativeScores
(
probs
=
all_probs
,
probs
=
all_probs
,
token_ids
=
all_tokens
,
token_ids
=
all_tokens
,
...
@@ -121,14 +128,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -121,14 +128,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# proposal len. This adds some complexity (splitting the batch into spec
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
# done by supporting per-sequence proposal lens.
spec_seqs
,
spec_indices
=
split_batch_by_proposal_len
(
(
spec_seqs
,
spec_indices
),
(
non_spec_seqs
,
non_spec_indices
)
=
\
seq_group_metadata_list
,
split_batch_by_proposal_len
(
proposal_lens_list
,
seq_group_metadata_list
,
proposal_lens_list
)
select_proposal_len_zero
=
False
)
non_spec_seqs
,
non_spec_indices
=
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
,
select_proposal_len_zero
=
True
)
target_seq_group_metadata_list
=
self
.
_create_scoring_model_input
(
target_seq_group_metadata_list
=
self
.
_create_scoring_model_input
(
seq_group_metadata_list
=
spec_seqs
,
seq_group_metadata_list
=
spec_seqs
,
...
@@ -171,7 +173,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -171,7 +173,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# The number of tokens in the expanded batch used for speculation is
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences.
# non-speculative sequences.
non_spec_expanded_bs
,
_
=
non_spec_target_token_ids
.
shape
non_spec_expanded_bs
=
len
(
non_spec_target_token_ids
)
spec_expanded_bs
=
expanded_batch_size
-
non_spec_expanded_bs
spec_expanded_bs
=
expanded_batch_size
-
non_spec_expanded_bs
target_token_ids
=
target_token_ids
.
reshape
(
spec_expanded_bs
,
k
+
1
)
target_token_ids
=
target_token_ids
.
reshape
(
spec_expanded_bs
,
k
+
1
)
...
@@ -181,7 +183,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -181,7 +183,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
if
target_hidden_states
is
not
None
:
if
target_hidden_states
is
not
None
:
target_hidden_states
=
target_hidden_states
.
reshape
(
target_hidden_states
=
target_hidden_states
.
reshape
(
spec_expanded_bs
,
k
+
1
,
target_hidden_states
.
shape
[
-
1
])
*
target_token_ids
.
shape
,
target_hidden_states
.
shape
[
-
1
])
all_tokens
=
target_token_ids
.
new_full
(
size
=
(
contracted_bs
,
k
+
1
),
all_tokens
=
target_token_ids
.
new_full
(
size
=
(
contracted_bs
,
k
+
1
),
fill_value
=-
1
)
fill_value
=-
1
)
...
@@ -196,24 +198,58 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -196,24 +198,58 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
all_hidden_states
=
None
all_hidden_states
=
None
if
non_spec_indices
:
if
non_spec_indices
:
all_tokens
[
non_spec_indices
,
:
1
]
=
non_spec_target_token_ids
all_tokens
[
non_spec_indices
,
:
1
]
=
\
all_probs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_probs
non_spec_target_token_ids
.
unsqueeze
(
1
)
all_logprobs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_logprobs
all_probs
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_target_probs
.
unsqueeze
(
1
)
all_logprobs
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_target_logprobs
.
unsqueeze
(
1
)
if
all_hidden_states
is
not
None
:
if
all_hidden_states
is
not
None
:
all_hidden_states
[
assert
non_spec_target_hidden_states
is
not
None
non_spec_indices
,
:
1
,
:]
=
non_spec_target_hidden_states
all_hidden_states
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_target_hidden_states
.
unsqueeze
(
1
)
if
spec_indices
:
if
spec_indices
:
all_tokens
[
spec_indices
]
=
target_token_ids
all_tokens
[
spec_indices
]
=
target_token_ids
all_probs
[
spec_indices
]
=
target_probs
all_probs
[
spec_indices
]
=
target_probs
all_logprobs
[
spec_indices
]
=
target_logprobs
all_logprobs
[
spec_indices
]
=
target_logprobs
if
all_hidden_states
is
not
None
:
if
all_hidden_states
is
not
None
:
all_hidden_states
[
spec_indices
]
=
target_hidden_states
all_hidden_states
[
spec_indices
]
=
target_hidden_states
return
all_tokens
,
all_probs
,
all_logprobs
,
all_hidden_states
return
all_tokens
,
all_probs
,
all_logprobs
,
all_hidden_states
def
_contract_batch_all_spec
(
self
,
target_sampler_output
:
SamplerOutput
,
proposals
:
SpeculativeProposals
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
It assumes all sequences in the batch were previously expanded.
"""
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
contracted_bs
,
k
=
proposals
.
proposal_token_ids
.
shape
# Reshape tensors to original batch size
target_token_ids
=
target_sampler_output
.
sampled_token_ids
.
reshape
(
contracted_bs
,
k
+
1
)
target_probs
=
target_sampler_output
.
sampled_token_probs
.
reshape
(
*
target_token_ids
.
shape
,
self
.
_vocab_size
)
target_logprobs
=
target_sampler_output
.
logprobs
.
reshape
(
target_probs
.
shape
)
target_hidden_states
=
target_sampler_output
.
hidden_states
if
target_hidden_states
is
not
None
:
target_hidden_states
=
target_hidden_states
.
reshape
(
*
target_token_ids
.
shape
,
target_hidden_states
.
shape
[
-
1
])
return
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
)
def
_create_scoring_model_input
(
def
_create_scoring_model_input
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
@@ -345,8 +381,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -345,8 +381,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_chunk_size
=
1
,
token_chunk_size
=
1
,
)
)
@
staticmethod
def
_split_scoring_output
(
def
_split_scoring_output
(
self
,
sampler_output
:
SamplerOutput
,
num_scoring_tokens
:
int
sampler_output
:
SamplerOutput
,
num_scoring_tokens
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
...
@@ -361,10 +398,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -361,10 +398,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
#
#
# First samples are from speculative scoring, latter samples are non-
# First samples are from speculative scoring, latter samples are non-
# speculative samples.
# speculative samples.
split_sizes
=
[
split_sizes
=
(
num_scoring_tokens
,
num_scoring_tokens
,
sampler_output
.
sampled_token_ids
.
numel
()
-
sampler_output
.
sampled_token_ids
.
numel
()
-
num_scoring_tokens
num_scoring_tokens
)
]
(
spec_probs
,
non_spec_probs
(
spec_probs
,
non_spec_probs
)
=
sampler_output
.
sampled_token_probs
.
split
(
split_sizes
)
)
=
sampler_output
.
sampled_token_probs
.
split
(
split_sizes
)
(
spec_sampled_tokens
,
non_spec_sampled_tokens
(
spec_sampled_tokens
,
non_spec_sampled_tokens
...
@@ -382,32 +418,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -382,32 +418,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else
:
else
:
spec_hidden_states
,
non_spec_hidden_states
=
None
,
None
spec_hidden_states
,
non_spec_hidden_states
=
None
,
None
# Convert scores to tensors.
return
(
spec_sampled_tokens
,
spec_probs
,
spec_logprobs
,
sampler_output
.
sampled_token_probs
=
spec_probs
spec_hidden_states
,
non_spec_sampled_tokens
,
non_spec_probs
,
sampler_output
.
sampled_token_ids
=
spec_sampled_tokens
non_spec_logprobs
,
non_spec_hidden_states
)
sampler_output
.
logprobs
=
spec_logprobs
sampler_output
.
hidden_states
=
spec_hidden_states
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
)
=
sampler_output_to_torch
([
sampler_output
],
True
)
# Convert non-speculative output tokens to tensors.
sampler_output
.
sampled_token_probs
=
non_spec_probs
sampler_output
.
sampled_token_ids
=
non_spec_sampled_tokens
sampler_output
.
logprobs
=
non_spec_logprobs
sampler_output
.
hidden_states
=
non_spec_hidden_states
(
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
,
non_spec_target_hidden_states
)
=
sampler_output_to_torch
(
[
sampler_output
],
True
)
return
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
,
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
,
non_spec_target_hidden_states
)
@
staticmethod
def
_create_target_seq_id_iterator
(
def
_create_target_seq_id_iterator
(
self
,
seq_ids
:
List
[
SeqId
])
->
Iterator
[
TargetSeqId
]:
seq_ids
:
List
[
SeqId
])
->
Iterator
[
TargetSeqId
]:
"""Create an iterator for creating target sequence ids.
"""Create an iterator for creating target sequence ids.
Target sequence ids are distinct from sequence ids because we create a
Target sequence ids are distinct from sequence ids because we create a
distinct target sequence id for each proposal token to be scored.
distinct target sequence id for each proposal token to be scored.
...
@@ -417,8 +434,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -417,8 +434,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
"""
"""
return
count
(
start
=
max
(
seq_ids
)
+
1
)
return
count
(
start
=
max
(
seq_ids
)
+
1
)
@
staticmethod
def
_get_token_ids_to_score
(
def
_get_token_ids_to_score
(
self
,
full_spec_token_ids
:
List
[
TokenId
]
# shape: [k]
full_spec_token_ids
:
List
[
TokenId
]
# shape: [k]
)
->
List
[
List
[
TokenId
]]:
)
->
List
[
List
[
TokenId
]]:
"""Given an int tensor of proposal token ids, return a list of
"""Given an int tensor of proposal token ids, return a list of
...
@@ -439,8 +456,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -439,8 +456,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
empty_token_ids
:
List
[
TokenId
]
=
[]
empty_token_ids
:
List
[
TokenId
]
=
[]
token_ids_to_score
=
[
empty_token_ids
]
token_ids_to_score
=
[
empty_token_ids
]
token_ids_to_score
.
extend
([
token_ids_to_score
.
extend
(
full_spec_token_ids
[:
i
+
1
]
full_spec_token_ids
[:
i
+
1
]
for
i
in
range
(
len
(
full_spec_token_ids
)))
for
i
in
range
(
len
(
full_spec_token_ids
))
])
return
token_ids_to_score
return
token_ids_to_score
vllm/spec_decode/spec_decode_worker.py
View file @
1856aff4
...
@@ -365,12 +365,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -365,12 +365,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# used during the prefill phase.
# used during the prefill phase.
# 2. Auto-disable enabled: The running queue size exceeds
# 2. Auto-disable enabled: The running queue size exceeds
# the specified threshold.
# the specified threshold.
# 3. No request: There are no requests in the batch.
# 3. No request: There are no requests in the batch, or
# none of the requests in the batch have spec decoding enabled.
# In any of these cases, the proposer and scorer workers
# In any of these cases, the proposer and scorer workers
# are called normally.
# are called normally.
no_spec
=
num_lookahead_slots
==
0
or
len
(
no_spec
=
num_lookahead_slots
==
0
or
disable_all_speculation
or
all
(
execute_model_req
.
seq_group_metadata_list
sgm
.
num_speculative_tokens
==
0
)
==
0
or
disable_all_speculation
for
sgm
in
execute_model_req
.
seq_group_metadata_list
)
# Broadcast how many lookahead slots are scheduled for this step, and
# Broadcast how many lookahead slots are scheduled for this step, and
# whether all speculation is disabled, to all non-driver workers.
# whether all speculation is disabled, to all non-driver workers.
...
@@ -415,10 +416,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -415,10 +416,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
,
execute_model_req
:
ExecuteModelRequest
)
->
bool
:
self
,
execute_model_req
:
ExecuteModelRequest
)
->
bool
:
# When the batch size is too large, disable speculative decoding
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
# to stop trading off throughput for latency.
disable_all_speculation
=
(
execute_model_req
.
running_queue_size
>=
return
(
execute_model_req
.
running_queue_size
>=
self
.
disable_by_batch_size
)
self
.
disable_by_batch_size
)
return
disable_all_speculation
def
_maybe_disable_speculative_tokens
(
def
_maybe_disable_speculative_tokens
(
self
,
disable_all_speculation
:
bool
,
self
,
disable_all_speculation
:
bool
,
...
@@ -621,14 +620,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -621,14 +620,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# proposal len. This adds some complexity (splitting the batch into spec
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
# done by supporting per-sequence proposal lens.
_
,
spec_indices
=
split_batch_by_proposal_len
(
(
_
,
spec_indices
),
(
_
,
non_spec_indices
)
=
split_batch_by_proposal_len
(
seq_group_metadata_list
,
seq_group_metadata_list
,
proposal_lens_list
)
proposal_lens_list
,
select_proposal_len_zero
=
False
)
_
,
non_spec_indices
=
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
,
select_proposal_len_zero
=
True
)
original_indices
=
spec_indices
+
non_spec_indices
original_indices
=
spec_indices
+
non_spec_indices
# Get probabilities of target model, excluding bonus token.
# Get probabilities of target model, excluding bonus token.
...
...
vllm/spec_decode/top1_proposer.py
View file @
1856aff4
...
@@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):
...
@@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):
# Currently only proposal lens of 0 or the global batch proposal len
# Currently only proposal lens of 0 or the global batch proposal len
# are supported.
# are supported.
# If max_proposal_len is defined, then we shall no exceed this
# If max_proposal_len is defined, then we shall no
t
exceed this
# quota for nonzero_proposal
# quota for nonzero_proposal
new_k
=
0
new_k
=
0
if
(
self
.
max_proposal_len
is
None
if
(
self
.
max_proposal_len
is
None
...
...
vllm/spec_decode/util.py
View file @
1856aff4
import
time
import
time
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Sequence
,
Tuple
import
torch
import
torch
...
@@ -98,33 +98,26 @@ def create_sequence_group_output(
...
@@ -98,33 +98,26 @@ def create_sequence_group_output(
def
split_batch_by_proposal_len
(
def
split_batch_by_proposal_len
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
proposal_lens
:
List
[
int
],
select_proposal_len_zero
:
bool
proposal_lens
:
List
[
int
],
)
->
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]:
)
->
Tuple
[
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]],
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]]:
"""Utility function that splits a batch based on whether the proposal len is
"""Utility function that splits a batch based on whether the proposal len is
zero or not. We should remove this once vLLM supports per-sequence proposal
zero or not. We should remove this once vLLM supports per-sequence proposal
lens in a batch.
lens in a batch.
"""
"""
if
select_proposal_len_zero
:
nonzero_lists
:
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]
=
([],
[])
predicate
=
lambda
proposal_len
:
proposal_len
==
0
zero_lists
:
Tuple
[
List
[
SequenceGroupMetadata
],
List
[
int
]]
=
([],
[])
else
:
for
i
,
(
seq_group
,
proposal_len
)
in
enumerate
(
predicate
=
lambda
proposal_len
:
proposal_len
!=
0
zip
(
seq_group_metadata_list
,
proposal_lens
)):
seq_groups
,
indices
=
nonzero_lists
if
proposal_len
else
zero_lists
indices
=
[
seq_groups
.
append
(
seq_group
)
i
for
i
,
(
_
,
proposal_len
indices
.
append
(
i
)
)
in
enumerate
(
zip
(
seq_group_metadata_list
,
proposal_lens
))
return
nonzero_lists
,
zero_lists
if
predicate
(
proposal_len
)
]
seq_groups
=
[
seq_group
for
seq_group
,
proposal_len
in
zip
(
seq_group_metadata_list
,
proposal_lens
)
if
predicate
(
proposal_len
)
]
return
seq_groups
,
indices
def
sampler_output_to_torch
(
def
sampler_output_to_torch
(
sampler_output_list
:
List
[
SamplerOutput
],
sampler_transposed
:
bool
sampler_output_list
:
Sequence
[
SamplerOutput
],
sampler_transposed
:
bool
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""Utility function which converts a list of SamplerOutput to tensors.
"""Utility function which converts a list of SamplerOutput to tensors.
...
@@ -148,18 +141,12 @@ def sampler_output_to_torch(
...
@@ -148,18 +141,12 @@ def sampler_output_to_torch(
dim
=
0
,
dim
=
0
,
)
)
if
sampler_transposed
:
sampled_token_probs
=
sampled_token_probs
.
transpose
(
0
,
1
)
# shape: [batch_size, num_sampler_output, vocab_size]
# shape: [batch_size, num_sampler_output, vocab_size]
sampled_token_logprobs
=
torch
.
stack
(
sampled_token_logprobs
=
torch
.
stack
(
[
sampler_output
.
logprobs
for
sampler_output
in
sampler_output_list
],
[
sampler_output
.
logprobs
for
sampler_output
in
sampler_output_list
],
dim
=
0
,
dim
=
0
,
)
)
if
sampler_transposed
:
sampled_token_logprobs
=
sampled_token_logprobs
.
transpose
(
0
,
1
)
# shape: [batch_size, num_sampler_output]
# shape: [batch_size, num_sampler_output]
sampled_token_ids
=
torch
.
stack
(
sampled_token_ids
=
torch
.
stack
(
[
[
...
@@ -168,7 +155,10 @@ def sampler_output_to_torch(
...
@@ -168,7 +155,10 @@ def sampler_output_to_torch(
],
],
dim
=
0
,
dim
=
0
,
)
)
if
sampler_transposed
:
if
sampler_transposed
:
sampled_token_probs
=
sampled_token_probs
.
transpose
(
0
,
1
)
sampled_token_logprobs
=
sampled_token_logprobs
.
transpose
(
0
,
1
)
sampled_token_ids
=
sampled_token_ids
.
transpose
(
0
,
1
)
sampled_token_ids
=
sampled_token_ids
.
transpose
(
0
,
1
)
if
sampler_output_list
[
0
].
hidden_states
is
not
None
:
if
sampler_output_list
[
0
].
hidden_states
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment