Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
89feb4c8
Unverified
Commit
89feb4c8
authored
Oct 11, 2024
by
Lily Liu
Committed by
GitHub
Oct 12, 2024
Browse files
[SpecDec] Remove Batch Expansion (2/3) (#9298)
parent
ec10cb85
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
122 additions
and
70 deletions
+122
-70
tests/spec_decode/test_scorer.py
tests/spec_decode/test_scorer.py
+39
-13
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+2
-5
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+42
-27
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+2
-5
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+1
-1
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+2
-5
vllm/spec_decode/mqa_scorer.py
vllm/spec_decode/mqa_scorer.py
+34
-8
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+0
-6
No files found.
tests/spec_decode/test_scorer.py
View file @
89feb4c8
import
random
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
...
@@ -10,31 +13,45 @@ from vllm.worker.worker import Worker
...
@@ -10,31 +13,45 @@ from vllm.worker.worker import Worker
from
.utils
import
create_batch
,
create_worker
from
.utils
import
create_batch
,
create_worker
def
create_proposal
(
batch_size
:
int
,
propose_len
:
int
,
vocab_size
:
int
,
def
create_proposal
(
propose_len
s
:
List
[
int
]
,
vocab_size
:
int
,
device
:
str
)
->
SpeculativeProposals
:
device
:
str
)
->
SpeculativeProposals
:
proposal_probs
=
torch
.
rand
((
batch_size
,
propose_len
,
vocab_size
),
batch_size
=
len
(
propose_lens
)
max_propose_len
=
max
(
propose_lens
)
proposal_probs
=
torch
.
rand
((
batch_size
,
max_propose_len
,
vocab_size
),
device
=
device
)
proposal_token_ids
=
torch
.
full
((
batch_size
,
max_propose_len
),
fill_value
=-
1
,
device
=
device
)
device
=
device
)
proposal_token_ids
=
torch
.
argmax
(
proposal_probs
,
dim
=-
1
)
for
i
in
range
(
batch_size
):
proposal_lens
=
torch
.
tensor
([
propose_len
]
*
batch_size
,
device
=
device
)
proposal_token_ids
[
i
][:
propose_lens
[
i
]]
=
torch
.
argmax
(
proposal_probs
[
i
][:
propose_lens
[
i
]],
dim
=-
1
)
propose_lens
=
torch
.
tensor
(
propose_lens
,
device
=
device
)
return
SpeculativeProposals
(
proposal_token_ids
,
proposal_probs
,
return
SpeculativeProposals
(
proposal_token_ids
,
proposal_probs
,
propos
al
_lens
)
propos
e
_lens
)
def
assert_score_equal
(
score1
:
SpeculativeScores
,
def
assert_score_equal
(
score1
:
SpeculativeScores
,
score2
:
SpeculativeScores
)
->
None
:
score2
:
SpeculativeScores
)
->
None
:
assert
torch
.
allclose
(
score1
.
probs
,
score2
.
probs
)
assert
torch
.
allclose
(
score1
.
probs
,
score2
.
probs
)
assert
torch
.
allclose
(
score1
.
logprobs
,
score2
.
logprobs
)
assert
torch
.
allclose
(
score1
.
logprobs
,
score2
.
logprobs
)
assert
torch
.
equal
(
score1
.
token_ids
,
score2
.
token_ids
)
assert
torch
.
equal
(
score1
.
token_ids
,
score2
.
token_ids
),
f
"
{
score1
.
token_ids
}
,
{
score2
.
token_ids
}
"
@
pytest
.
mark
.
parametrize
(
'model_name'
,
[
'facebook/opt-125m'
])
@
pytest
.
mark
.
parametrize
(
'model_name'
,
[
'facebook/opt-125m'
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
4
,
8
,
16
])
@
pytest
.
mark
.
parametrize
(
'batch_size'
,
[
1
,
2
,
4
,
8
,
16
])
@
pytest
.
mark
.
parametrize
(
'propose_len'
,
[
1
,
3
,
5
])
@
pytest
.
mark
.
parametrize
(
'max_propose_len'
,
[
1
,
3
,
5
])
@
pytest
.
mark
.
parametrize
(
'mixed_propose_len'
,
[
True
])
@
pytest
.
mark
.
parametrize
(
'device'
,
[
'cuda'
])
@
pytest
.
mark
.
parametrize
(
'device'
,
[
'cuda'
])
def
test_scor
o
er
(
model_name
:
str
,
batch_size
:
int
,
propose_len
:
int
,
def
test_scorer
(
model_name
:
str
,
batch_size
:
int
,
max_
propose_len
:
int
,
device
:
str
)
->
None
:
mixed_propose_len
:
bool
,
device
:
str
)
->
None
:
"""
"""
Compare the batch expansion scorer and mqa scorer return the same score
Check that the batch expansion scorer and the MQA scorer return the same scores.
We test both queries that share the same proposal length and queries with
different proposal lengths.
"""
"""
seed
=
0
seed
=
0
block_size
=
32
block_size
=
32
...
@@ -46,13 +63,22 @@ def test_scoroer(model_name: str, batch_size: int, propose_len: int,
...
@@ -46,13 +63,22 @@ def test_scoroer(model_name: str, batch_size: int, propose_len: int,
should_modify_greedy_probs_inplace
=
True
should_modify_greedy_probs_inplace
=
True
vocab_size
=
scorer_worker
.
vocab_size
vocab_size
=
scorer_worker
.
vocab_size
proposals
=
create_proposal
(
batch_size
,
propose_len
,
vocab_size
,
device
)
if
not
mixed_propose_len
:
propose_lens
=
[
max_propose_len
]
*
batch_size
else
:
non_zero_cnt
=
random
.
randint
(
0
,
batch_size
)
propose_lens
=
[
max_propose_len
]
*
non_zero_cnt
+
[
0
]
*
(
batch_size
-
non_zero_cnt
)
random
.
shuffle
(
propose_lens
)
proposals
=
create_proposal
(
propose_lens
,
vocab_size
,
device
)
seq_group_metadatalist
,
_
,
_
=
create_batch
(
batch_size
,
seq_group_metadatalist
,
_
,
_
=
create_batch
(
batch_size
,
propose_len
,
max_
propose_len
,
block_size
=
block_size
,
block_size
=
block_size
,
num_gpu_blocks
=
num_gpu_blocks
)
num_gpu_blocks
=
num_gpu_blocks
)
requests
=
ExecuteModelRequest
(
seq_group_metadatalist
,
requests
=
ExecuteModelRequest
(
seq_group_metadatalist
,
num_lookahead_slots
=
propose_len
)
num_lookahead_slots
=
max_
propose_len
)
batch_expansion_scorer
=
BatchExpansionTop1Scorer
(
scorer_worker
,
device
,
batch_expansion_scorer
=
BatchExpansionTop1Scorer
(
scorer_worker
,
device
,
vocab_size
)
vocab_size
)
...
...
vllm/attention/backends/blocksparse_attn.py
View file @
89feb4c8
...
@@ -186,11 +186,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
...
@@ -186,11 +186,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
use_cuda_graph
:
bool
# Number of query tokens for each request in the batch.
# Max number of query tokens among requests in the batch.
# Currently, we require that all requests have the same number of query
max_decode_query_len
:
Optional
[
int
]
=
None
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len
:
Optional
[
int
]
=
None
_cached_prefill_metadata
:
Optional
[
_cached_prefill_metadata
:
Optional
[
"BlocksparseFlashAttentionMetadata"
]
=
None
"BlocksparseFlashAttentionMetadata"
]
=
None
...
...
vllm/attention/backends/flash_attn.py
View file @
89feb4c8
...
@@ -111,11 +111,8 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -111,11 +111,8 @@ class FlashAttentionMetadata(AttentionMetadata):
# Maximum query length in the batch.
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
max_query_len
:
Optional
[
int
]
# Number of query tokens for each request in the batch.
# Max number of query tokens among requests in the batch.
# Currently, we require that all requests have the same number of query
max_decode_query_len
:
Optional
[
int
]
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len
:
Optional
[
int
]
# Maximum sequence length among prefill batch. 0 if there are decoding
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
# requests only.
...
@@ -173,9 +170,9 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -173,9 +170,9 @@ class FlashAttentionMetadata(AttentionMetadata):
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
decode_query_len
=
0
,
max_query_len
=
self
.
max_query_len
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_query_len
=
0
,
max_decode_seq_len
=
0
,
max_decode_seq_len
=
0
,
query_start_loc
=
self
.
query_start_loc
[:
self
.
num_prefills
+
1
],
query_start_loc
=
self
.
query_start_loc
[:
self
.
num_prefills
+
1
],
seq_start_loc
=
self
.
seq_start_loc
[:
self
.
num_prefills
+
1
],
seq_start_loc
=
self
.
seq_start_loc
[:
self
.
num_prefills
+
1
],
...
@@ -202,12 +199,14 @@ class FlashAttentionMetadata(AttentionMetadata):
...
@@ -202,12 +199,14 @@ class FlashAttentionMetadata(AttentionMetadata):
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
seq_lens
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
decode_query_len
=
self
.
decode_query_len
,
max_
decode_query_len
=
self
.
max_
decode_query_len
,
max_query_len
=
self
.
max_query_len
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
0
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
query_start_loc
=
None
,
query_start_loc
=
self
.
query_start_loc
[
self
.
num_prefills
:]
seq_start_loc
=
None
,
if
self
.
query_start_loc
is
not
None
else
None
,
seq_start_loc
=
self
.
seq_start_loc
[
self
.
num_prefills
:]
if
self
.
seq_start_loc
is
not
None
else
None
,
context_lens_tensor
=
None
,
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
use_cuda_graph
=
self
.
use_cuda_graph
,
...
@@ -413,9 +412,9 @@ class FlashAttentionMetadataBuilder(
...
@@ -413,9 +412,9 @@ class FlashAttentionMetadataBuilder(
max_query_len
=
max
(
query_lens
)
max_query_len
=
max
(
query_lens
)
decode_query_lens
=
query_lens
[
self
.
num_prefills
:]
decode_query_lens
=
query_lens
[
self
.
num_prefills
:]
if
len
(
decode_query_lens
)
>
0
:
if
len
(
decode_query_lens
)
>
0
:
decode_query_len
=
max
(
decode_query_lens
)
max_
decode_query_len
=
max
(
decode_query_lens
)
else
:
else
:
decode_query_len
=
1
max_
decode_query_len
=
1
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
num_decode_tokens
=
self
.
num_decode_tokens
...
@@ -468,7 +467,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -468,7 +467,7 @@ class FlashAttentionMetadataBuilder(
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
decode_query_len
=
decode_query_len
,
max_
decode_query_len
=
max_
decode_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
...
@@ -714,11 +713,28 @@ def unified_flash_attention(
...
@@ -714,11 +713,28 @@ def unified_flash_attention(
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Decoding run.
_
,
num_head
,
head_dim
=
decode_query
.
shape
# Use flash_attn_varlen_func kernel for speculative decoding
decode_query
=
decode_query
.
reshape
(
-
1
,
decode_meta
.
decode_query_len
,
# because different queries might have different lengths.
num_head
,
head_dim
)
assert
decode_meta
.
max_decode_query_len
is
not
None
decode_output
=
flash_attn_with_kvcache
(
if
decode_meta
.
max_decode_query_len
>
1
:
decode_output
=
flash_attn_varlen_func
(
q
=
decode_query
,
q
=
decode_query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
decode_meta
.
query_start_loc
,
max_seqlen_q
=
decode_meta
.
max_decode_query_len
,
cu_seqlens_k
=
decode_meta
.
seq_start_loc
,
max_seqlen_k
=
decode_meta
.
max_decode_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
)
else
:
# Use flash_attn_with_kvcache for normal decoding.
decode_output
=
flash_attn_with_kvcache
(
q
=
decode_query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
k_cache
=
key_cache
,
v_cache
=
value_cache
,
v_cache
=
value_cache
,
block_table
=
decode_meta
.
block_tables
,
block_table
=
decode_meta
.
block_tables
,
...
@@ -739,7 +755,6 @@ def unified_flash_attention(
...
@@ -739,7 +755,6 @@ def unified_flash_attention(
# Chunked prefill does not work with speculative decoding.
# Chunked prefill does not work with speculative decoding.
# Therefore, the query length for decode should be 1 in chunked prefill.
# Therefore, the query length for decode should be 1 in chunked prefill.
assert
decode_meta
is
not
None
assert
decode_meta
is
not
None
assert
decode_meta
.
decode_query_len
==
1
decode_output
=
decode_output
.
squeeze
(
1
)
decode_output
=
decode_output
.
squeeze
(
1
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
output
=
torch
.
cat
([
prefill_output
,
decode_output
],
dim
=
0
)
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
89feb4c8
...
@@ -121,11 +121,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -121,11 +121,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
# so far).
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Number of query tokens for each request in the batch.
# Max number of query tokens among requests in the batch.
# Currently, we require that all requests have the same number of query
max_decode_query_len
:
Optional
[
int
]
=
None
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len
:
Optional
[
int
]
=
None
_cached_prefill_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
_cached_prefill_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
...
...
vllm/attention/backends/utils.py
View file @
89feb4c8
...
@@ -313,7 +313,7 @@ class CommonAttentionState(AttentionState):
...
@@ -313,7 +313,7 @@ class CommonAttentionState(AttentionState):
seq_lens
=
None
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
_graph_seq_lens
[:
batch_size
],
seq_lens_tensor
=
self
.
_graph_seq_lens
[:
batch_size
],
max_query_len
=
1
,
max_query_len
=
1
,
decode_query_len
=
1
,
max_
decode_query_len
=
1
,
max_prefill_seq_len
=
0
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
runner
.
max_seq_len_to_capture
,
max_decode_seq_len
=
self
.
runner
.
max_seq_len_to_capture
,
query_start_loc
=
None
,
query_start_loc
=
None
,
...
...
vllm/attention/backends/xformers.py
View file @
89feb4c8
...
@@ -118,11 +118,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
...
@@ -118,11 +118,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Maximum query length in the batch. None for decoding.
# Maximum query length in the batch. None for decoding.
max_query_len
:
Optional
[
int
]
=
None
max_query_len
:
Optional
[
int
]
=
None
# Number of query tokens for each request in the batch.
# Max number of query tokens among requests in the batch.
# Currently, we require that all requests have the same number of query
max_decode_query_len
:
Optional
[
int
]
=
None
# tokens during the decoding phase. When speculative decoding is enabled,
# decode_query_len might be greater than 1. In all other cases, it is 1.
decode_query_len
:
Optional
[
int
]
=
None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# the batch, used to index into subquery. E.g., if the subquery length
...
...
vllm/spec_decode/mqa_scorer.py
View file @
89feb4c8
...
@@ -18,6 +18,7 @@ class MQAScorer(SpeculativeScorer):
...
@@ -18,6 +18,7 @@ class MQAScorer(SpeculativeScorer):
target_seq_id_start
=
max
(
target_seq_id_start
=
max
(
get_all_seq_ids
(
execute_model_req
.
seq_group_metadata_list
))
+
1
get_all_seq_ids
(
execute_model_req
.
seq_group_metadata_list
))
+
1
all_proposal_tokens
=
proposals
.
proposal_token_ids
.
tolist
()
all_proposal_tokens
=
proposals
.
proposal_token_ids
.
tolist
()
all_proposal_lengths
=
proposals
.
proposal_lens
.
tolist
()
for
i
,
seq_group_metadata
in
enumerate
(
for
i
,
seq_group_metadata
in
enumerate
(
execute_model_req
.
seq_group_metadata_list
):
execute_model_req
.
seq_group_metadata_list
):
seq_data_dict
=
seq_group_metadata
.
seq_data
seq_data_dict
=
seq_group_metadata
.
seq_data
...
@@ -27,7 +28,8 @@ class MQAScorer(SpeculativeScorer):
...
@@ -27,7 +28,8 @@ class MQAScorer(SpeculativeScorer):
seq_data
:
SequenceData
=
seq_data_dict
[
seq_id
]
seq_data
:
SequenceData
=
seq_data_dict
[
seq_id
]
prompt_token_ids
=
seq_data
.
get_prompt_token_ids
()
prompt_token_ids
=
seq_data
.
get_prompt_token_ids
()
output_token_ids
=
seq_data
.
get_output_token_ids
()
output_token_ids
=
seq_data
.
get_output_token_ids
()
proposal_token_ids
=
all_proposal_tokens
[
i
]
proposal_token_ids
=
all_proposal_tokens
[
i
][:
all_proposal_lengths
[
i
]]
new_output_token_ids
=
[
*
output_token_ids
,
*
proposal_token_ids
]
new_output_token_ids
=
[
*
output_token_ids
,
*
proposal_token_ids
]
target_seq_id
=
target_seq_id_start
+
i
target_seq_id
=
target_seq_id_start
+
i
...
@@ -62,18 +64,42 @@ class MQAScorer(SpeculativeScorer):
...
@@ -62,18 +64,42 @@ class MQAScorer(SpeculativeScorer):
target_sampler_output
=
target_sampler_output
[
0
]
target_sampler_output
=
target_sampler_output
[
0
]
bs
,
k
=
proposals
.
proposal_token_ids
.
shape
k
=
execute_model_req
.
num_lookahead_slots
all_tokens
=
target_sampler_output
.
sampled_token_ids
.
reshape
(
bs
,
k
+
1
)
bs
=
len
(
execute_model_req
.
seq_group_metadata_list
)
target_token_ids
=
target_sampler_output
.
sampled_token_ids
all_probs
=
target_sampler_output
.
sampled_token_probs
.
reshape
(
target_probs
=
target_sampler_output
.
sampled_token_probs
bs
,
k
+
1
,
self
.
_vocab_size
)
target_logprobs
=
target_sampler_output
.
logprobs
all_logprobs
=
target_sampler_output
.
logprobs
.
reshape
(
# If all requests have the same number of query tokens, we can avoid
bs
,
k
+
1
,
self
.
_vocab_size
)
# the for loop to build output for better performance.
if
min
(
all_proposal_lengths
)
==
k
:
bs
,
_
=
proposals
.
proposal_token_ids
.
shape
all_tokens
=
target_token_ids
.
reshape
(
bs
,
k
+
1
)
all_probs
=
target_probs
.
reshape
(
bs
,
k
+
1
,
self
.
_vocab_size
)
all_logprobs
=
target_logprobs
.
reshape
(
bs
,
k
+
1
,
self
.
_vocab_size
)
else
:
all_tokens
=
target_token_ids
.
new_full
(
size
=
(
bs
,
k
+
1
),
fill_value
=-
1
)
all_probs
=
target_probs
.
new_zeros
(
*
all_tokens
.
shape
,
self
.
_vocab_size
)
all_logprobs
=
target_logprobs
.
new_full
(
size
=
all_probs
.
shape
,
fill_value
=-
float
(
"inf"
))
target_token_ids
=
target_token_ids
.
flatten
()
start_loc
=
0
for
i
,
proposed_len
in
enumerate
(
all_proposal_lengths
):
output_len
=
proposed_len
+
1
end_loc
=
start_loc
+
output_len
all_tokens
[
i
,
:
output_len
]
=
target_token_ids
[
start_loc
:
end_loc
]
all_probs
[
i
,
:
output_len
]
=
target_probs
[
start_loc
:
end_loc
]
all_logprobs
[
i
,
:
output_len
]
=
target_logprobs
[
start_loc
:
end_loc
]
start_loc
=
end_loc
hidden_states
=
None
hidden_states
=
None
if
target_sampler_output
.
hidden_states
is
not
None
:
if
target_sampler_output
.
hidden_states
is
not
None
:
hidden_states
=
target_sampler_output
.
hidden_states
.
reshape
(
hidden_states
=
target_sampler_output
.
hidden_states
.
reshape
(
bs
,
(
k
+
1
),
-
1
)
bs
,
(
k
+
1
),
-
1
)
return
SpeculativeScores
(
probs
=
all_probs
,
return
SpeculativeScores
(
probs
=
all_probs
,
token_ids
=
all_tokens
,
token_ids
=
all_tokens
,
logprobs
=
all_logprobs
,
logprobs
=
all_logprobs
,
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
89feb4c8
...
@@ -190,12 +190,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -190,12 +190,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"[Speculative Decoding] Disabling MQA scorer as the "
"[Speculative Decoding] Disabling MQA scorer as the "
"MQA is only available with flash attn backend."
)
"MQA is only available with flash attn backend."
)
if
ngram_prompt_lookup_max
>
0
:
disable_mqa_scorer
=
True
logger
.
info
(
"[Speculative Decoding] Disabling MQA scorer as the "
"NGramWorker does not support MQA scorer."
)
if
"model_config"
in
draft_worker_kwargs
and
\
if
"model_config"
in
draft_worker_kwargs
and
\
draft_worker_kwargs
[
"model_config"
].
max_model_len
<
\
draft_worker_kwargs
[
"model_config"
].
max_model_len
<
\
scorer_worker
.
model_config
.
max_model_len
:
scorer_worker
.
model_config
.
max_model_len
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment