Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e95cd879
Unverified
Commit
e95cd879
authored
Apr 16, 2024
by
Cade Daniel
Committed by
GitHub
Apr 16, 2024
Browse files
[Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894)
parent
69e1d2fb
Changes
31
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
140 additions
and
33 deletions
+140
-33
vllm/executor/neuron_executor.py
vllm/executor/neuron_executor.py
+4
-1
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+2
-1
vllm/sequence.py
vllm/sequence.py
+13
-0
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+18
-5
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+15
-1
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+35
-9
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+26
-0
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+5
-3
vllm/worker/neuron_worker.py
vllm/worker/neuron_worker.py
+7
-4
vllm/worker/worker.py
vllm/worker/worker.py
+8
-3
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+7
-6
No files found.
vllm/executor/neuron_executor.py
View file @
e95cd879
...
@@ -48,10 +48,13 @@ class NeuronExecutor(ExecutorBase):
...
@@ -48,10 +48,13 @@ class NeuronExecutor(ExecutorBase):
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
)
->
List
[
SamplerOutput
]:
assert
(
blocks_to_swap_in
==
{}
and
blocks_to_swap_out
==
{}
assert
(
blocks_to_swap_in
==
{}
and
blocks_to_swap_out
==
{}
and
blocks_to_copy
==
{}),
(
and
blocks_to_copy
==
{}),
(
"Cache operations are not supported for Neuron backend."
)
"Cache operations are not supported for Neuron backend."
)
assert
num_lookahead_slots
==
0
,
(
"lookahead not supported for Neuron backend."
)
output
=
self
.
driver_worker
.
execute_model
(
output
=
self
.
driver_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
)
seq_group_metadata_list
=
seq_group_metadata_list
)
...
...
vllm/executor/ray_gpu_executor.py
View file @
e95cd879
...
@@ -242,7 +242,8 @@ class RayGPUExecutor(ExecutorBase):
...
@@ -242,7 +242,8 @@ class RayGPUExecutor(ExecutorBase):
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
num_lookahead_slots
:
int
=
0
)
->
SamplerOutput
:
all_outputs
=
self
.
_run_workers
(
all_outputs
=
self
.
_run_workers
(
"execute_model"
,
"execute_model"
,
driver_kwargs
=
{
driver_kwargs
=
{
...
...
vllm/sequence.py
View file @
e95cd879
...
@@ -693,3 +693,16 @@ class SamplerOutput:
...
@@ -693,3 +693,16 @@ class SamplerOutput:
def
__eq__
(
self
,
other
:
object
):
def
__eq__
(
self
,
other
:
object
):
return
isinstance
(
other
,
return
isinstance
(
other
,
self
.
__class__
)
and
self
.
outputs
==
other
.
outputs
self
.
__class__
)
and
self
.
outputs
==
other
.
outputs
def
__repr__
(
self
)
->
str
:
"""Show the shape of a tensor instead of its values to reduce noise.
"""
sampled_token_probs_repr
=
(
"None"
if
self
.
sampled_token_probs
is
None
else
self
.
sampled_token_probs
.
shape
)
sampled_token_ids_repr
=
(
"None"
if
self
.
sampled_token_ids
is
None
else
self
.
sampled_token_ids
.
shape
)
return
(
f
"SamplerOutput(outputs=
{
self
.
outputs
}
, "
f
"sampled_token_probs=
{
sampled_token_probs_repr
}
, "
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
vllm/spec_decode/batch_expansion.py
View file @
e95cd879
...
@@ -6,10 +6,10 @@ import torch
...
@@ -6,10 +6,10 @@ import torch
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
maybe_mock_device_tensors
,
sampler_output_to_torch
,
nvtx_range
,
sampler_output_to_torch
,
split_batch_by_proposal_len
)
split_batch_by_proposal_len
)
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
_base
import
Worker
Base
SeqId
=
int
SeqId
=
int
TargetSeqId
=
int
TargetSeqId
=
int
...
@@ -31,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -31,7 +31,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
of topk/tree.
of topk/tree.
"""
"""
def
__init__
(
self
,
scorer_worker
:
Worker
,
device
:
str
,
vocab_size
:
int
):
def
__init__
(
self
,
scorer_worker
:
WorkerBase
,
device
:
str
,
vocab_size
:
int
):
self
.
_scorer_worker
=
scorer_worker
self
.
_scorer_worker
=
scorer_worker
self
.
_device
=
device
self
.
_device
=
device
self
.
_vocab_size
=
vocab_size
self
.
_vocab_size
=
vocab_size
...
@@ -83,7 +84,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -83,7 +84,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
return_python_output
=
False
)
)
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
all_tokens
,
all_probs
=
self
.
_contract_batch
(
all_tokens
,
all_probs
=
self
.
_contract_batch
(
original_bs
=
len
(
seq_group_metadata_list
),
original_bs
=
len
(
seq_group_metadata_list
),
...
@@ -142,6 +145,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -142,6 +145,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
This maps the scores of speculative tokens back to their original
This maps the scores of speculative tokens back to their original
sequences.
sequences.
"""
"""
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
maybe_mock_device_tensors
(
sampler_output
=
target_sampler_output
,
batch_size
=
len
(
non_spec_indices
)
+
num_scoring_tokens
,
vocab_size
=
self
.
_vocab_size
,
device
=
self
.
_device
,
)
(
target_token_ids
,
target_probs
,
non_spec_target_token_ids
,
(
target_token_ids
,
target_probs
,
non_spec_target_token_ids
,
non_spec_target_probs
)
=
self
.
_split_scoring_output
(
non_spec_target_probs
)
=
self
.
_split_scoring_output
(
target_sampler_output
,
num_scoring_tokens
)
target_sampler_output
,
num_scoring_tokens
)
...
...
vllm/spec_decode/multi_step_worker.py
View file @
e95cd879
...
@@ -6,7 +6,8 @@ import torch
...
@@ -6,7 +6,8 @@ import torch
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
SpeculativeProposer
)
from
vllm.spec_decode.util
import
sampler_output_to_torch
from
vllm.spec_decode.util
import
(
maybe_mock_device_tensors
,
sampler_output_to_torch
)
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
import
Worker
...
@@ -69,6 +70,9 @@ class MultiStepWorker(Worker):
...
@@ -69,6 +70,9 @@ class MultiStepWorker(Worker):
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
)
)
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
model_output
=
model_output
[
0
]
self
.
_append_new_tokens
(
model_output
,
self
.
_append_new_tokens
(
model_output
,
copied_seq_group_metadata_list
)
copied_seq_group_metadata_list
)
...
@@ -341,6 +345,16 @@ class DraftModelTop1Proposer(SpeculativeProposer):
...
@@ -341,6 +345,16 @@ class DraftModelTop1Proposer(SpeculativeProposer):
sampler_output
=
maybe_sampler_output
sampler_output
=
maybe_sampler_output
# We mock the device tensors until PR 7/9 is merged (e2e correctness).
# https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
for
step_output
in
sampler_output
:
maybe_mock_device_tensors
(
sampler_output
=
step_output
,
batch_size
=
len
(
proposal_lens
),
vocab_size
=
self
.
_vocab_size
,
device
=
self
.
_device
,
)
proposal_tokens
,
proposal_probs
=
sampler_output_to_torch
(
proposal_tokens
,
proposal_probs
=
sampler_output_to_torch
(
sampler_output
)
sampler_output
)
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
e95cd879
...
@@ -3,8 +3,9 @@ from typing import Dict, List, Optional, Tuple
...
@@ -3,8 +3,9 @@ from typing import Dict, List, Optional, Tuple
import
torch
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.model_executor.layers.rejection_sampler
import
RejectionSampler
from
vllm.sequence
import
(
SamplerOutput
,
SequenceGroupMetadata
,
from
vllm.sequence
import
(
Logprob
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceOutput
)
SequenceGroupOutput
,
SequenceOutput
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
...
@@ -13,8 +14,9 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
...
@@ -13,8 +14,9 @@ from vllm.spec_decode.metrics import AsyncMetricsCollector
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
from
vllm.spec_decode.util
import
(
get_all_seq_ids
,
nvtx_range
,
split_batch_by_proposal_len
)
split_batch_by_proposal_len
)
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
,
WorkerBase
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
logger
=
init_logger
(
__name__
)
class
SpecDecodeWorker
(
LoraNotSupportedWorkerBase
):
class
SpecDecodeWorker
(
LoraNotSupportedWorkerBase
):
...
@@ -45,10 +47,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -45,10 +47,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit.
"""
"""
@
classmethod
def
from_workers
(
cls
,
proposer_worker
:
MultiStepWorker
,
scorer_worker
:
WorkerBase
)
->
"SpecDecodeWorker"
:
return
SpecDecodeWorker
(
proposer_worker
,
scorer_worker
,
# TODO(cade) disable strict mode for speedup.
rejection_sampler
=
RejectionSampler
(
strict_mode
=
True
),
)
def
__init__
(
def
__init__
(
self
,
self
,
proposer_worker
:
MultiStepWorker
,
proposer_worker
:
MultiStepWorker
,
scorer_worker
:
Worker
,
scorer_worker
:
Worker
Base
,
rejection_sampler
:
RejectionSampler
,
rejection_sampler
:
RejectionSampler
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
metrics_collector
:
Optional
[
AsyncMetricsCollector
]
=
None
,
):
):
...
@@ -87,6 +99,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -87,6 +99,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
scorer_worker
.
init_device
()
self
.
scorer_worker
.
init_device
()
self
.
proposer_worker
.
init_device
()
self
.
proposer_worker
.
init_device
()
# NOTE(cade): load_model is not part of the WorkerBase interface.
self
.
scorer_worker
.
load_model
()
self
.
proposer_worker
.
load_model
()
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
rejection_sampler
.
init_gpu_tensors
(
self
.
rank
)
self
.
rejection_sampler
.
init_gpu_tensors
(
self
.
rank
)
self
.
scorer
=
BatchExpansionTop1Scorer
(
self
.
scorer
=
BatchExpansionTop1Scorer
(
...
@@ -131,7 +147,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -131,7 +147,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]],
num_
spec_token
s
:
int
,
num_
lookahead_slot
s
:
int
,
)
->
List
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
"""Perform speculative decoding on the input batch.
"""Perform speculative decoding on the input batch.
"""
"""
...
@@ -140,9 +156,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -140,9 +156,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"speculative decoding "
"speculative decoding "
"requires non-None seq_group_metadata_list"
)
"requires non-None seq_group_metadata_list"
)
logger
.
info
(
f
"spec_decode_worker.execute_model
{
num_lookahead_slots
=
}
"
)
# If no spec tokens, call the proposer and scorer workers normally.
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
# Used for prefill.
if
num_
spec_token
s
==
0
or
len
(
seq_group_metadata_list
)
==
0
:
if
num_
lookahead_slot
s
==
0
or
len
(
seq_group_metadata_list
)
==
0
:
return
self
.
_run_no_spec
(
return
self
.
_run_no_spec
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
...
@@ -155,7 +173,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -155,7 +173,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
k
=
num_
spec_token
s
,
k
=
num_
lookahead_slot
s
,
)
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
@
nvtx_range
(
"spec_decode_worker._run_no_spec"
)
...
@@ -170,20 +188,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -170,20 +188,24 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer and scorer model so that the KV cache is consistent between the
proposer and scorer model so that the KV cache is consistent between the
two.
two.
"""
"""
logger
.
info
(
"run proposer worker no spec"
)
self
.
proposer_worker
.
execute_model
(
self
.
proposer_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
return_python_output
=
False
)
)
logger
.
info
(
"run target worker no spec"
)
sampler_output
=
self
.
scorer_worker
.
execute_model
(
sampler_output
=
self
.
scorer_worker
.
execute_model
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
)
)
assert
len
(
sampler_output
)
==
1
sampler_output
=
sampler_output
[
0
]
# Clear device tensors from sampler output. This reduces communication
# Clear device tensors from sampler output. This reduces communication
# overhead when the engine runs in a different process than the workers.
# overhead when the engine runs in a different process than the workers.
...
@@ -209,11 +231,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -209,11 +231,13 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sequence.
sequence.
"""
"""
logger
.
info
(
"get spec proposals"
)
# Generate proposals using draft worker.
# Generate proposals using draft worker.
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
proposals
=
self
.
proposer_worker
.
get_spec_proposals
(
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_out
,
blocks_to_copy
,
k
)
blocks_to_copy
,
k
)
logger
.
info
(
"score proposals"
)
proposal_scores
=
self
.
scorer
.
score_proposals
(
proposal_scores
=
self
.
scorer
.
score_proposals
(
seq_group_metadata_list
,
seq_group_metadata_list
,
blocks_to_swap_in
,
blocks_to_swap_in
,
...
@@ -223,9 +247,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -223,9 +247,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposals
,
proposals
,
)
)
logger
.
info
(
"verify proposals"
)
accepted_token_ids
=
self
.
_verify_tokens
(
seq_group_metadata_list
,
accepted_token_ids
=
self
.
_verify_tokens
(
seq_group_metadata_list
,
proposal_scores
,
proposals
,
k
)
proposal_scores
,
proposals
,
k
)
logger
.
info
(
"create output list"
)
return
self
.
_create_output_sampler_list
(
seq_group_metadata_list
,
return
self
.
_create_output_sampler_list
(
seq_group_metadata_list
,
accepted_token_ids
,
k
)
accepted_token_ids
,
k
)
...
@@ -311,7 +337,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -311,7 +337,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
parent_seq_id
=
seq_id
,
parent_seq_id
=
seq_id
,
output_token
=
token_id
,
output_token
=
token_id
,
# TODO Add verifier logprobs.
# TODO Add verifier logprobs.
logprobs
=
{
token_id
:
0.0
},
logprobs
=
{
token_id
:
Logprob
(
0.0
)
},
)
)
],
],
prompt_logprobs
=
None
,
prompt_logprobs
=
None
,
...
...
vllm/spec_decode/util.py
View file @
e95cd879
...
@@ -82,6 +82,32 @@ def sampler_output_to_torch(
...
@@ -82,6 +82,32 @@ def sampler_output_to_torch(
return
sampled_token_ids
,
sampled_token_probs
return
sampled_token_ids
,
sampled_token_probs
def
maybe_mock_device_tensors
(
sampler_output
:
SamplerOutput
,
batch_size
:
int
,
vocab_size
:
int
,
device
:
str
)
->
None
:
"""Helper method which mocks out the GPU tensors in SamplerOutput with dummy
values. This will be removed in PR 7/9.
https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
"""
values
=
[
sampler_output
.
sampled_token_probs
,
sampler_output
.
sampled_token_ids
]
assert
all
(
v
is
None
for
v
in
values
)
or
not
any
(
v
is
None
for
v
in
values
)
if
not
any
(
v
is
None
for
v
in
values
):
# Do nothing if the tensors are already created (usually in unit tests).
return
# Softmax to ensure valid probs.
sampler_output
.
sampled_token_probs
=
torch
.
nn
.
functional
.
softmax
(
torch
.
rand
(
batch_size
,
vocab_size
,
dtype
=
torch
.
float32
,
device
=
device
),
dim
=-
1
)
sampler_output
.
sampled_token_ids
=
torch
.
randint
(
low
=
10
,
high
=
100
,
size
=
(
batch_size
,
),
dtype
=
torch
.
long
,
device
=
device
)
@
contextmanager
@
contextmanager
def
nvtx_range
(
msg
,
*
args
,
**
kwargs
):
def
nvtx_range
(
msg
,
*
args
,
**
kwargs
):
"""
"""
...
...
vllm/worker/cpu_worker.py
View file @
e95cd879
...
@@ -251,7 +251,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
...
@@ -251,7 +251,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]]
=
None
,
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]]
=
None
,
)
->
Optional
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
assert
seq_group_metadata_list
is
not
None
assert
seq_group_metadata_list
is
not
None
num_seq_groups
=
len
(
seq_group_metadata_list
)
num_seq_groups
=
len
(
seq_group_metadata_list
)
...
@@ -274,11 +274,13 @@ class CPUWorker(LoraNotSupportedWorkerBase):
...
@@ -274,11 +274,13 @@ class CPUWorker(LoraNotSupportedWorkerBase):
# If there is no input, we don't need to execute the model.
# If there is no input, we don't need to execute the model.
if
num_seq_groups
==
0
:
if
num_seq_groups
==
0
:
return
{}
return
[]
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
self
.
cpu_cache
)
self
.
cpu_cache
)
return
output
# CPU worker only supports single-step execution.
return
[
output
]
def
init_distributed_environment
(
self
)
->
None
:
def
init_distributed_environment
(
self
)
->
None
:
"""Initialize the distributed environment."""
"""Initialize the distributed environment."""
...
...
vllm/worker/neuron_worker.py
View file @
e95cd879
"""A Neuron worker class."""
"""A Neuron worker class."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Tuple
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -73,15 +73,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
...
@@ -73,15 +73,18 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
def
execute_model
(
def
execute_model
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Optional
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
num_seq_groups
=
len
(
seq_group_metadata_list
)
num_seq_groups
=
len
(
seq_group_metadata_list
)
# If there is no input, we don't need to execute the model.
# If there is no input, we don't need to execute the model.
if
num_seq_groups
==
0
:
if
num_seq_groups
==
0
:
return
{}
return
[]
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
)
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
)
return
output
# Neuron worker only supports single-step output. Wrap the output in a
# list to conform to interface.
return
[
output
]
def
get_cache_block_size_bytes
(
self
)
->
int
:
def
get_cache_block_size_bytes
(
self
)
->
int
:
"""Determine the size in bytes of a cache block.
"""Determine the size in bytes of a cache block.
...
...
vllm/worker/worker.py
View file @
e95cd879
...
@@ -210,7 +210,9 @@ class Worker(WorkerBase):
...
@@ -210,7 +210,9 @@ class Worker(WorkerBase):
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_in
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_swap_out
:
Optional
[
Dict
[
int
,
int
]]
=
None
,
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]]
=
None
,
blocks_to_copy
:
Optional
[
Dict
[
int
,
List
[
int
]]]
=
None
,
)
->
Optional
[
SamplerOutput
]:
num_lookahead_slots
:
int
=
0
,
)
->
List
[
SamplerOutput
]:
if
self
.
is_driver_worker
:
if
self
.
is_driver_worker
:
assert
seq_group_metadata_list
is
not
None
assert
seq_group_metadata_list
is
not
None
num_seq_groups
=
len
(
seq_group_metadata_list
)
num_seq_groups
=
len
(
seq_group_metadata_list
)
...
@@ -235,11 +237,14 @@ class Worker(WorkerBase):
...
@@ -235,11 +237,14 @@ class Worker(WorkerBase):
# If there is no input, we don't need to execute the model.
# If there is no input, we don't need to execute the model.
if
num_seq_groups
==
0
:
if
num_seq_groups
==
0
:
return
{}
return
[]
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
output
=
self
.
model_runner
.
execute_model
(
seq_group_metadata_list
,
self
.
gpu_cache
)
self
.
gpu_cache
)
return
output
# Worker only supports single-step execution. Wrap the output in a list
# to conform to interface.
return
[
output
]
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
self
.
model_runner
.
add_lora
(
lora_request
)
return
self
.
model_runner
.
add_lora
(
lora_request
)
...
...
vllm/worker/worker_base.py
View file @
e95cd879
...
@@ -40,12 +40,13 @@ class WorkerBase(ABC):
...
@@ -40,12 +40,13 @@ class WorkerBase(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
execute_model
(
self
,
def
execute_model
(
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
blocks_to_swap_out
:
Dict
[
int
,
int
],
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
SamplerOutput
:
blocks_to_copy
:
Dict
[
int
,
List
[
int
]])
->
List
[
SamplerOutput
]:
"""Executes one model step on the given sequences."""
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment