Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a0086298
Unverified
Commit
a0086298
authored
Jun 10, 2024
by
Nick Hill
Committed by
GitHub
Jun 11, 2024
Browse files
[Misc] Various simplifications and typing fixes (#5368)
parent
76477a93
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
63 additions
and
90 deletions
+63
-90
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+1
-1
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+4
-2
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+12
-27
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+7
-7
vllm/spec_decode/top1_proposer.py
vllm/spec_decode/top1_proposer.py
+18
-27
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+3
-8
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+3
-3
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+15
-15
No files found.
vllm/engine/output_processor/multi_step.py
View file @
a0086298
...
...
@@ -78,7 +78,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Since there's only one sequence per sequence group, we can take the
# first sample.
samples
=
[
output
s
[
step
]
.
samples
[
0
]
for
step
in
range
(
len
(
outputs
))
]
samples
=
[
output
.
samples
[
0
]
for
output
in
outputs
]
# -1 means the output token is not valid (eg. due to spec decode
# rejecting tokens).
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
a0086298
...
...
@@ -306,8 +306,10 @@ class RejectionSampler(nn.Module):
# Fill in the first k columns of the output tensor using masks and data
# tensors.
output
[:,
:
k
]
=
torch
.
where
(
accepted_mask
,
draft_token_ids
,
-
torch
.
ones_like
(
draft_token_ids
))
torch
.
where
(
accepted_mask
,
draft_token_ids
,
-
torch
.
ones_like
(
draft_token_ids
),
out
=
output
)
# Fill the last column.
# We check output directly as accepted may have True values inconsistent
...
...
vllm/spec_decode/batch_expansion.py
View file @
a0086298
...
...
@@ -80,7 +80,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
target_sampler_output
=
self
.
_scorer_worker
.
execute_model
(
execute_model_req
=
execute_model_req
.
clone
(
seq_group_metadata_list
=
target_seq_group_metadata_list
,
))
seq_group_metadata_list
=
target_seq_group_metadata_list
))
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
...
...
@@ -140,8 +140,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
num_scoring_tokens
)
def
_contract_batch
(
self
,
contracted_bs
:
int
,
target_sampler_output
:
List
[
SamplerOutput
],
self
,
contracted_bs
:
int
,
target_sampler_output
:
SamplerOutput
,
proposals
:
SpeculativeProposals
,
num_scoring_tokens
:
int
,
non_spec_indices
:
List
[
int
],
spec_indices
:
List
[
int
],
k
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -167,30 +166,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
non_spec_expanded_bs
,
_
=
non_spec_target_token_ids
.
shape
spec_expanded_bs
=
expanded_batch_size
-
non_spec_expanded_bs
target_token_ids
=
target_token_ids
.
squeeze
().
reshape
(
spec_expanded_bs
,
k
+
1
)
target_probs
=
target_probs
.
squeeze
().
reshape
(
spec_expanded_bs
,
k
+
1
,
self
.
_vocab_size
)
target_logprobs
=
target_logprobs
.
squeeze
().
reshape
(
spec_expanded_bs
,
k
+
1
,
self
.
_vocab_size
)
all_tokens
=
torch
.
full
(
size
=
(
contracted_bs
,
k
+
1
),
fill_value
=-
1
,
device
=
self
.
_device
,
dtype
=
torch
.
long
)
all_probs
=
torch
.
zeros
(
contracted_bs
,
k
+
1
,
self
.
_vocab_size
,
device
=
self
.
_device
,
dtype
=
torch
.
float32
)
all_logprobs
=
torch
.
full
(
size
=
(
contracted_bs
,
k
+
1
,
self
.
_vocab_size
,
),
fill_value
=-
float
(
"inf"
),
device
=
self
.
_device
,
dtype
=
torch
.
float32
)
target_token_ids
=
target_token_ids
.
reshape
(
spec_expanded_bs
,
k
+
1
)
target_probs
=
target_probs
.
reshape
(
*
target_token_ids
.
shape
,
self
.
_vocab_size
)
target_logprobs
=
target_logprobs
.
reshape
(
target_probs
.
shape
)
all_tokens
=
target_token_ids
.
new_full
(
size
=
(
contracted_bs
,
k
+
1
),
fill_value
=-
1
)
all_probs
=
target_probs
.
new_zeros
(
*
all_tokens
.
shape
,
self
.
_vocab_size
)
all_logprobs
=
target_logprobs
.
new_full
(
size
=
all_probs
.
shape
,
fill_value
=-
float
(
"inf"
))
if
non_spec_indices
:
all_tokens
[
non_spec_indices
,
:
1
]
=
non_spec_target_token_ids
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
a0086298
...
...
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
import
torch
from
vllm.config
import
SpeculativeConfig
from
vllm.distributed.communication_op
import
broadcast_tensor_dict
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
...
...
@@ -30,7 +31,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
"""
assert
"speculative_config"
in
kwargs
speculative_config
=
kwargs
.
get
(
"speculative_config"
)
speculative_config
:
SpeculativeConfig
=
kwargs
.
get
(
"speculative_config"
)
assert
speculative_config
is
not
None
target_worker
=
Worker
(
*
args
,
**
kwargs
)
...
...
@@ -109,12 +110,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
logger
.
info
(
"Configuring SpecDecodeWorker with proposer=%s"
,
type
(
proposer_worker
))
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
disable_by_batch_size
=
disable_by_batch_size
,
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
disable_bonus_tokens
,
))
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
disable_by_batch_size
=
disable_by_batch_size
,
rejection_sampler
=
RejectionSampler
(
disable_bonus_tokens
=
disable_bonus_tokens
))
def
__init__
(
self
,
...
...
vllm/spec_decode/top1_proposer.py
View file @
a0086298
...
...
@@ -148,7 +148,8 @@ class Top1Proposer(SpeculativeProposer):
nonzero_proposal_len_indices
,
)
def
_remove_no_proposal_seqs
(
self
,
proposal_lens
,
maybe_sampler_output
,
@
staticmethod
def
_remove_no_proposal_seqs
(
proposal_lens
,
maybe_sampler_output
,
nonzero_proposal_len_indices
,
transposed
):
"""Remove sequences from nonzero_proposal_len_indices and reset
their proposal_len to 0 the draft worker does not provide a proposal
...
...
@@ -207,7 +208,7 @@ class Top1Proposer(SpeculativeProposer):
self
,
batch_size
:
int
,
proposal_len
:
int
,
maybe_sampler_output
:
Optional
[
SamplerOutput
],
maybe_sampler_output
:
Optional
[
List
[
SamplerOutput
]
]
,
proposal_lens
:
List
[
int
],
nonzero_proposal_len_indices
:
List
[
int
],
sampler_transposed
:
bool
,
...
...
@@ -218,25 +219,19 @@ class Top1Proposer(SpeculativeProposer):
if
maybe_sampler_output
is
None
:
# If no speculative tokens, the sampler output will be None.
# In this case we return empty proposals.
proposal_tokens
=
torch
.
full
(
size
=
(
batch_size
,
proposal_len
,
),
fill_value
=-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
,
)
proposal_probs
=
torch
.
zeros
(
batch_size
,
proposal_len
,
self
.
_vocab_size
,
dtype
=
torch
.
float32
,
device
=
self
.
_device
,
)
proposal_lens_tensor
=
torch
.
zeros
(
len
(
proposal_lens
),
dtype
=
torch
.
long
,
device
=
self
.
_device
)
proposal_tokens
=
torch
.
tensor
(
-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
).
expand
(
batch_size
,
proposal_len
)
proposal_probs
=
torch
.
tensor
(
0
,
dtype
=
torch
.
float32
,
device
=
self
.
_device
).
expand
(
batch_size
,
proposal_len
,
self
.
_vocab_size
)
proposal_lens_tensor
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
self
.
_device
).
expand
(
len
(
proposal_lens
))
return
proposal_tokens
,
proposal_probs
,
proposal_lens_tensor
sampler_output
=
maybe_sampler_output
...
...
@@ -246,18 +241,14 @@ class Top1Proposer(SpeculativeProposer):
# Now, reformat the output GPU tensors such that each sequence has
# a proposal. the proposal can be empty, e.g. [-1, -1, -1]
entire_proposal_tokens
=
torch
.
full
(
entire_proposal_tokens
=
proposal_tokens
.
new_
full
(
size
=
(
batch_size
,
*
proposal_tokens
.
shape
[
1
:]),
fill_value
=-
1
,
dtype
=
torch
.
long
,
device
=
self
.
_device
,
)
entire_proposal_tokens
[
nonzero_proposal_len_indices
]
=
proposal_tokens
entire_proposal_probs
=
torch
.
zeros
(
entire_proposal_probs
=
proposal_probs
.
new_
zeros
(
batch_size
,
*
proposal_probs
.
shape
[
1
:],
dtype
=
torch
.
float32
,
device
=
self
.
_device
,
)
entire_proposal_probs
[
nonzero_proposal_len_indices
]
=
proposal_probs
...
...
vllm/spec_decode/util.py
View file @
a0086298
from
contextlib
import
contextmanager
from
itertools
import
chain
from
typing
import
Dict
,
List
,
Tuple
import
torch
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceOutput
)
SequenceOutput
)
SeqId
=
int
...
...
@@ -16,11 +15,7 @@ def get_all_seq_ids(
"""Given a list of SequenceGroupMetadata, create a list of all
sequence ids.
"""
return
list
(
chain
.
from_iterable
([
seq_group_metadata
.
seq_data
.
keys
()
for
seq_group_metadata
in
seq_group_metadata_list
]))
return
[
seq_id
for
sg
in
seq_group_metadata_list
for
seq_id
in
sg
.
seq_data
]
def
get_all_num_logprobs
(
...
...
@@ -68,7 +63,7 @@ def create_sequence_group_output(
seq_id
:
SeqId
,
topk_token_ids
:
List
[
int
],
topk_logprobs
:
List
[
float
],
)
->
SequenceGroupOutput
:
)
->
Completion
SequenceGroupOutput
:
"""Create a SequenceGroupOutput given the sampling results.
Args:
...
...
vllm/transformers_utils/config.py
View file @
a0086298
from
typing
import
Dict
,
Optional
from
typing
import
Dict
,
Optional
,
Type
from
transformers
import
PretrainedConfig
...
...
@@ -9,7 +9,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
logger
=
init_logger
(
__name__
)
_CONFIG_REGISTRY
:
Dict
[
str
,
PretrainedConfig
]
=
{
_CONFIG_REGISTRY
:
Dict
[
str
,
Type
[
PretrainedConfig
]
]
=
{
"chatglm"
:
ChatGLMConfig
,
"dbrx"
:
DbrxConfig
,
"mpt"
:
MPTConfig
,
...
...
@@ -68,4 +68,4 @@ def get_hf_text_config(config: PretrainedConfig):
assert
hasattr
(
config
.
text_config
,
"num_attention_heads"
)
return
config
.
text_config
else
:
return
config
\ No newline at end of file
return
config
vllm/worker/model_runner.py
View file @
a0086298
...
...
@@ -527,16 +527,6 @@ class ModelRunner:
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
...
...
@@ -544,11 +534,6 @@ class ModelRunner:
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
...
...
@@ -601,6 +586,21 @@ class ModelRunner:
seq_start_loc
=
seq_start_loc
,
data_type
=
kv_cache_dtype
)
else
:
context_lens_tensor
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment