Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e95cd879
Unverified
Commit
e95cd879
authored
Apr 16, 2024
by
Cade Daniel
Committed by
GitHub
Apr 16, 2024
Browse files
[Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894)
parent
69e1d2fb
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
140 additions
and
33 deletions
+140
-33
vllm/executor/neuron_executor.py
vllm/executor/neuron_executor.py
+4
-1
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+2
-1
vllm/sequence.py
vllm/sequence.py
+13
-0
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+18
-5
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+15
-1
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+35
-9
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+26
-0
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+5
-3
vllm/worker/neuron_worker.py
vllm/worker/neuron_worker.py
+7
-4
vllm/worker/worker.py
vllm/worker/worker.py
+8
-3
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+7
-6
No files found.
vllm/executor/neuron_executor.py
View file @
e95cd879
...
...
@@ -48,10 +48,13 @@ class NeuronExecutor(ExecutorBase):
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
)
->
List
[
SamplerOutput
]:
assert
(
blocks_to_swap_in
==
{}
and
blocks_to_swap_out
==
{}
and
blocks_to_copy
==
{}),
(
"Cache operations are not supported for Neuron backend."
)
assert
num_lookahead_slots
==
0
,
(
"lookahead not supported for Neuron backend."
)
output
=
self
.
driver_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
)
...
...
vllm/executor/ray_gpu_executor.py
View file @
e95cd879
...
...
@@ -242,7 +242,8 @@ class RayGPUExecutor(ExecutorBase):
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
=
0
)
->
SamplerOutput
:
all_outputs
=
self
.
_run_workers
(
"execute_model"
,
driver_kwargs
=
{
...
...
vllm/sequence.py
View file @
e95cd879
...
...
@@ -693,3 +693,16 @@ class SamplerOutput:
def
__eq__
(
self
,
other
:
object
):
return
isinstance
(
other
,
self
.
__class__
)
and
self
.
outputs
==
other
.
outputs
def
__repr__
(
self
)
->
str
:
"""Show the shape of a tensor instead of its values to reduce noise.
"""
sampled_token_probs_repr
=
(
"None"
if
self
.
sampled_token_probs
is
None
else
self
.
sampled_token_probs
.
shape
)
sampled_token_ids_repr
=
(
"None"
if
self
.
sampled_token_ids
is
None
else
self
.
sampled_token_ids
.
shape
)
return
(
f
"SamplerOutput(outputs=
{
self
.
outputs
}
, "
f
"sampled_token_probs=
{
sampled_token_probs_repr
}
, "
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
vllm/spec_decode/batch_expansion.py
View file @
e95cd879
...
...
@@ -6,10 +6,10 @@ import torch
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
sampler_output_to_torch
,
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
maybe_mock_device_tensors
,
nvtx_range
,
sampler_output_to_torch
,
split_batch_by_proposal_len
)
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
_base
import
Worker
Base
SeqId
=
int
TargetSeqId
=
int
...
...
@@ -31,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
of topk/tree.
"""
def
__init__
(
self
,
scorer_worker
:
Worker
,
device
:
str
,
vocab_size
:
int
):
def
__init__
(
self
,
scorer_worker
:
WorkerBase
,
device
:
str
,
vocab_size
:
int
):
self
.
_scorer_worker
=
scorer_worker
self
.
_device
=
device
self
.
_vocab_size
=
vocab_size
...
...
@@ -83,7 +84,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
return_python_output
=
False
)
)
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
all_tokens
,
all_probs
=
self
.
_contract_batch
(
original_bs
=
len
(
seq_group_metadata_list
),
...
...
@@ -142,6 +145,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
This maps the scores of speculative tokens back to their original
sequences.
"""
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
maybe_mock_device_tensors
(
sampler_output
=
target_sampler_output
,
batch_size
=
len
(
non_spec_indices
)
+
num_scoring_tokens
,
vocab_size
=
self
.
_vocab_size
,
device
=
self
.
_device
,
)
(
target_token_ids
,
target_probs
,
non_spec_target_token_ids
,
non_spec_target_probs
)
=
self
.
_split_scoring_output
(
target_sampler_output
,
num_scoring_tokens
)
...
...
vllm/spec_decode/multi_step_worker.py
View file @
e95cd879
...
...
@@ -6,7 +6,8 @@ import torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.spec_decode.util
import
(
maybe_mock_device_tensors
,
sampler_output_to_torch
)
from
vllm.worker.worker
import
Worker
...
...
@@ -69,6 +70,9 @@ class MultiStepWorker(Worker):
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
model_output
=
model_output
[
0
]
self
.
_append_new_tokens
(
model_output
,
copied_seq_group_metadata_list
)
...
...
@@ -341,6 +345,16 @@ class DraftModelTop1Proposer(SpeculativeProposer):
sampler_output
=
maybe_sampler_output
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
for
step_output
in
sampler_output
:
maybe_mock_device_tensors
(
sampler_output
=
step_output
,
batch_size
=
len
(
proposal_lens
),
vocab_size
=
self
.
_vocab_size
,
device
=
self
.
_device
,
)
proposal_tokens
,
proposal_probs
=
sampler_output_to_torch
(
sampler_output
)
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
e95cd879
...
...
@@ -3,8 +3,9 @@ from typing import Dict, List, Optional, Tuple
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.sequence
import
(
SamplerOutput
,
SequenceGroupMetadata
,
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceOutput
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
...
...
@@ -13,8 +14,9 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
split_batch_by_proposal_len
)
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
,
WorkerBase
logger
=
init_logger
(
__name__
)
class
SpecDecodeWorker
(
LoraNotSupportedWorkerBase
):
...
...
@@ -45,10 +47,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
"""
@
classmethod
def
from_workers
(
cls
,
proposer_worker
:
MultiStepWorker
,
scorer_worker
:
WorkerBase
)
->
"SpecDecodeWorker"
:
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
# TODO(cade) disable strict mode for speedup.
rejection_sampler
=
RejectionSampler
(
strict_mode
=
True
),
)
def
__init__
(
self
,
proposer_worker
:
MultiStepWorker
,
scorer_worker
:
Worker
,
scorer_worker
:
Worker
Base
,
rejection_sampler
:
RejectionSampler
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
):
...
...
@@ -87,6 +99,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
scorer_worker
.
init_device
()
self
.
proposer_worker
.
init_device
()
# NOTE(cade): load_model is not part of the WorkerBase interface.
self
.
scorer_worker
.
load_model
()
self
.
proposer_worker
.
load_model
()
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
rejection_sampler
.
init_gpu_tensors
(
self
.
rank
)
self
.
scorer
=
BatchExpansionTop1Scorer
(
...
...
@@ -131,7 +147,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
num_
spec_token
s
:
int
,
num_
lookahead_slot
s
:
int
,
)
->
List
[
SamplerOutput
]:
"""Perform speculative decoding on the input batch.
"""
...
...
@@ -140,9 +156,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"speculative decoding "
"requires non-None seq_group_metadata_list"
)
logger
.
info
(
f
"spec_decode_worker.execute_model
{
num_lookahead_slots
=
}
"
)
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
if
num_
spec_token
s
==
0
or
len
(
seq_group_metadata_list
)
==
0
:
if
num_
lookahead_slot
s
==
0
or
len
(
seq_group_metadata_list
)
==
0
:
return
self
.
_run_no_spec
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
...
...
@@ -155,7 +173,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
k
=
num_
spec_token
s
,
k
=
num_
lookahead_slot
s
,
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
...
...
@@ -170,20 +188,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer and scorer model so that the KV cache is consistent between the
two.
"""
logger
.
info
(
"run proposer worker no spec"
)
self
.
proposer_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
return_python_output
=
False
)
)
logger
.
info
(
"run target worker no spec"
)
sampler_output
=
self
.
scorer_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
)
assert
len
(
sampler_output
)
==
1
sampler_output
=
sampler_output
[
0
]
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
...
...
@@ -209,11 +231,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sequence.
"""
logger
.
info
(
"get spec proposals"
)
# Generate proposals using draft worker.
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
k
)
logger
.
info
(
"score proposals"
)
proposal_scores
=
self
.
scorer
.
score_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
...
...
@@ -223,9 +247,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals
,
)
logger
.
info
(
"verify proposals"
)
accepted_token_ids
=
self
.
_verify_tokens
(
seq_group_metadata_list
,
proposal_scores
,
proposals
,
k
)
logger
.
info
(
"create output list"
)
return
self
.
_create_output_sampler_list
(
seq_group_metadata_list
,
accepted_token_ids
,
k
)
...
...
@@ -311,7 +337,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
parent_seq_id
=
seq_id
,
output_token
=
token_id
,
# TODO Add verifier logprobs.
logprobs
=
{
token_id
:
0.0
},
logprobs
=
{
token_id
:
Logprob
(
0.0
)
},
)
],
prompt_logprobs
=
None
,
...
...
vllm/spec_decode/util.py
View file @
e95cd879
...
...
@@ -82,6 +82,32 @@ def sampler_output_to_torch(
return
sampled_token_ids
,
sampled_token_probs
def
maybe_mock_device_tensors
(
sampler_output
:
SamplerOutput
,
batch_size
:
int
,
vocab_size
:
int
,
device
:
str
)
->
None
:
"""Helper method which mocks out the GPU tensors in SamplerOutput with dummy
values. This will be removed in PR 7/9.
https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
"""
values
=
[
sampler_output
.
sampled_token_probs
,
sampler_output
.
sampled_token_ids
]
assert
all
(
v
is
None
for
v
in
values
)
or
not
any
(
v
is
None
for
v
in
values
)
if
not
any
(
v
is
None
for
v
in
values
):
# Do nothing if the tensors are already created (usually in unit tests).
return
# Softmax to ensure valid probs.
sampler_output
.
sampled_token_probs
=
torch
.
nn
.
functional
.
softmax
(
torch
.
rand
(
batch_size
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
device
),
dim
=-
1
)
sampler_output
.
sampled_token_ids
=
torch
.
randint
(
low
=
10
,
high
=
100
,
size
=
(
batch_size
,
),
dtype
=
torch
.
long
,
device
=
device
)
@
contextmanager
def
nvtx_range
(
msg
,
*
args
,
**
kwargs
):
"""
...
...
vllm/worker/cpu_worker.py
View file @
e95cd879
...
...
@@ -251,7 +251,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]]
=
None
,
)
->
Optional
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
if
self
.
is_driver_worker
:
assert
seq_group_metadata_list
is
not
None
num_seq_groups
=
len
(
seq_group_metadata_list
)
...
...
@@ -274,11 +274,13 @@ class CPUWorker(LoraNotSupportedWorkerBase):
# If there is no input, we don't need to execute the model.
if
num_seq_groups
==
0
:
return
{}
return
[]
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
self
.
cpu_cache
)
return
output
# CPU worker only supports single-step execution.
return
[
output
]
def
init_distributed_environment
(
self
)
->
None
:
"""Initialize the distributed environment."""
...
...
vllm/worker/neuron_worker.py
View file @
e95cd879
"""A Neuron worker class."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Tuple
import
torch
import
torch.distributed
...
...
@@ -73,15 +73,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Optional
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
num_seq_groups
=
len
(
seq_group_metadata_list
)
# If there is no input, we don't need to execute the model.
if
num_seq_groups
==
0
:
return
{}
return
[]
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
)
return
output
# Neuron worker only supports single-step output. Wrap the output in a
# list to conform to interface.
return
[
output
]
def
get_cache_block_size_bytes
(
self
)
->
int
:
"""Determine the size in bytes of a cache block.
...
...
vllm/worker/worker.py
View file @
e95cd879
...
...
@@ -210,7 +210,9 @@ class Worker(WorkerBase):
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]]
=
None
,
)
->
Optional
[
SamplerOutput
]:
num_lookahead_slots
:
int
=
0
,
)
->
List
[
SamplerOutput
]:
if
self
.
is_driver_worker
:
assert
seq_group_metadata_list
is
not
None
num_seq_groups
=
len
(
seq_group_metadata_list
)
...
...
@@ -235,11 +237,14 @@ class Worker(WorkerBase):
# If there is no input, we don't need to execute the model.
if
num_seq_groups
==
0
:
return
{}
return
[]
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
self
.
gpu_cache
)
return
output
# Worker only supports single-step execution. Wrap the output in a list
# to conform to interface.
return
[
output
]
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
self
.
model_runner
.
add_lora
(
lora_request
)
...
...
vllm/worker/worker_base.py
View file @
e95cd879
...
...
@@ -40,12 +40,13 @@ class WorkerBase(ABC):
raise
NotImplementedError
@
abstractmethod
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
"""Executes one model step on the given sequences."""
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
List
[
SamplerOutput
]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
raise
NotImplementedError
@
abstractmethod
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment