Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0a71900b
Unverified
Commit
0a71900b
authored
Nov 26, 2024
by
Chendi.Xue
Committed by
GitHub
Nov 26, 2024
Browse files
Remove hard-dependencies of Speculative decode to CUDA workers (#10587)
Signed-off-by:
Chendi Xue
<
chendi.xue@intel.com
>
parent
2f0a0a17
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
219 additions
and
77 deletions
+219
-77
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+2
-2
vllm/config.py
vllm/config.py
+1
-0
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+16
-1
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+7
-1
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+3
-1
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+12
-12
vllm/spec_decode/interfaces.py
vllm/spec_decode/interfaces.py
+5
-3
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+5
-4
vllm/spec_decode/metrics.py
vllm/spec_decode/metrics.py
+14
-1
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+24
-7
vllm/spec_decode/ngram_worker.py
vllm/spec_decode/ngram_worker.py
+2
-1
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+25
-11
vllm/spec_decode/target_model_runner.py
vllm/spec_decode/target_model_runner.py
+11
-22
vllm/spec_decode/util.py
vllm/spec_decode/util.py
+8
-4
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+37
-2
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+25
-2
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+15
-0
vllm/worker/worker.py
vllm/worker/worker.py
+4
-3
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+3
-0
No files found.
tests/spec_decode/test_spec_decode_worker.py
View file @
0a71900b
...
@@ -595,8 +595,8 @@ def test_init_device(acceptance_sampler_method: str):
...
@@ -595,8 +595,8 @@ def test_init_device(acceptance_sampler_method: str):
target_worker
.
init_device
.
assert_called_once
()
target_worker
.
init_device
.
assert_called_once
()
metrics_collector
.
init_
gpu_
tensors
.
assert_called_once
()
metrics_collector
.
init_tensors
.
assert_called_once
()
spec_decode_sampler
.
init_
gpu_
tensors
.
assert_called_once
()
spec_decode_sampler
.
init_tensors
.
assert_called_once
()
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
@
pytest
.
mark
.
parametrize
(
"acceptance_sampler_method"
,
...
...
vllm/config.py
View file @
0a71900b
...
@@ -990,6 +990,7 @@ class ParallelConfig:
...
@@ -990,6 +990,7 @@ class ParallelConfig:
# the full name of the worker class to use. If "auto", the worker class
# the full name of the worker class to use. If "auto", the worker class
# will be determined based on the platform.
# will be determined based on the platform.
worker_cls
:
str
=
"auto"
worker_cls
:
str
=
"auto"
sd_worker_cls
:
str
=
"auto"
world_size
:
int
=
field
(
init
=
False
)
world_size
:
int
=
field
(
init
=
False
)
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
0a71900b
...
@@ -43,6 +43,21 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -43,6 +43,21 @@ class SpecDecodeBaseSampler(nn.Module):
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
)
device
=
device
)
def
init_tensors
(
self
,
device
:
Union
[
int
,
str
],
device_type
:
Union
[
torch
.
device
,
str
]
=
'cuda'
)
->
None
:
assert
self
.
num_accepted_tokens
is
None
if
isinstance
(
device_type
,
torch
.
device
):
device_type
=
device_type
.
type
if
isinstance
(
device
,
int
):
device
=
f
"
{
device_type
}
:
{
device
}
"
self
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
device
)
self
.
num_emitted_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
long
,
device
=
device
)
@
property
@
property
def
probs_dtype
(
self
):
def
probs_dtype
(
self
):
return
torch
.
float32
return
torch
.
float32
...
@@ -77,7 +92,7 @@ class SpecDecodeBaseSampler(nn.Module):
...
@@ -77,7 +92,7 @@ class SpecDecodeBaseSampler(nn.Module):
tensor is [batch_size, k + num_bonus_tokens]
tensor is [batch_size, k + num_bonus_tokens]
"""
"""
batch_size
,
k
=
substitute_token_ids
.
shape
batch_size
,
k
=
substitute_token_ids
.
shape
bonus_token_ids
=
bonus_token_ids
.
squeeze
()
bonus_token_ids
=
bonus_token_ids
.
squeeze
(
-
1
)
# Determine the index of the first False value for each row.
# Determine the index of the first False value for each row.
limits
=
(
accepted
==
0
).
max
(
1
).
indices
limits
=
(
accepted
==
0
).
max
(
1
).
indices
limits
[
~
(
accepted
==
0
).
any
(
1
)]
=
k
limits
[
~
(
accepted
==
0
).
any
(
1
)]
=
k
...
...
vllm/platforms/cpu.py
View file @
0a71900b
...
@@ -86,4 +86,10 @@ class CpuPlatform(Platform):
...
@@ -86,4 +86,10 @@ class CpuPlatform(Platform):
parallel_config
.
distributed_executor_backend
)
parallel_config
.
distributed_executor_backend
)
parallel_config
.
distributed_executor_backend
=
"mp"
parallel_config
.
distributed_executor_backend
=
"mp"
if
parallel_config
.
worker_cls
==
"auto"
:
if
parallel_config
.
worker_cls
==
"auto"
:
parallel_config
.
worker_cls
=
"vllm.worker.cpu_worker.CPUWorker"
if
vllm_config
.
speculative_config
:
parallel_config
.
worker_cls
=
\
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config
.
sd_worker_cls
=
\
"vllm.worker.cpu_worker.CPUWorker"
else
:
parallel_config
.
worker_cls
=
"vllm.worker.cpu_worker.CPUWorker"
vllm/platforms/cuda.py
View file @
0a71900b
...
@@ -106,6 +106,8 @@ class CudaPlatformBase(Platform):
...
@@ -106,6 +106,8 @@ class CudaPlatformBase(Platform):
elif
vllm_config
.
speculative_config
:
elif
vllm_config
.
speculative_config
:
parallel_config
.
worker_cls
=
\
parallel_config
.
worker_cls
=
\
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config
.
sd_worker_cls
=
\
"vllm.worker.worker.Worker"
else
:
else
:
parallel_config
.
worker_cls
=
"vllm.worker.worker.Worker"
parallel_config
.
worker_cls
=
"vllm.worker.worker.Worker"
...
@@ -236,4 +238,4 @@ try:
...
@@ -236,4 +238,4 @@ try:
if
not
isinstance
(
pynvml
,
_MockModule
):
if
not
isinstance
(
pynvml
,
_MockModule
):
CudaPlatform
.
log_warnings
()
CudaPlatform
.
log_warnings
()
except
ModuleNotFoundError
:
except
ModuleNotFoundError
:
CudaPlatform
.
log_warnings
()
CudaPlatform
.
log_warnings
()
\ No newline at end of file
vllm/spec_decode/draft_model_runner.py
View file @
0a71900b
...
@@ -20,8 +20,9 @@ except (ModuleNotFoundError, ImportError) as err:
...
@@ -20,8 +20,9 @@ except (ModuleNotFoundError, ImportError) as err:
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunner
)
ModelRunnerInputBase
,
ModelRunnerWrapperBase
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -33,7 +34,7 @@ debug_advance_input = False
...
@@ -33,7 +34,7 @@ debug_advance_input = False
allow_gpu_advance_step
=
True
allow_gpu_advance_step
=
True
class
TP1DraftModelRunner
(
ModelRunner
):
class
TP1DraftModelRunner
(
ModelRunner
WrapperBase
):
"""Specialized model runner for speculative decoding draft model.
"""Specialized model runner for speculative decoding draft model.
Since the draft model always execute k forward passes consecutively to
Since the draft model always execute k forward passes consecutively to
generate k speculative tokens in a single speculative decoding step,
generate k speculative tokens in a single speculative decoding step,
...
@@ -46,13 +47,14 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -46,13 +47,14 @@ class TP1DraftModelRunner(ModelRunner):
any broadcasting inside execute_model).
any broadcasting inside execute_model).
"""
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
model_runner
:
ModelRunnerBase
):
if
kwargs
.
get
(
"return_hidden_states"
):
if
hasattr
(
model_runner
,
"return_hidden_states"
)
and
model_runner
.
return_hidden_states
:
raise
ValueError
(
raise
ValueError
(
"return_hidden_states is not supported for TP1DraftModelRunner."
"return_hidden_states is not supported for TP1DraftModelRunner."
)
)
super
().
__init__
(
model_runner
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
indices_of_seq_with_bonus_tokens
=
None
self
.
indices_of_seq_with_bonus_tokens
=
None
...
@@ -73,10 +75,8 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -73,10 +75,8 @@ class TP1DraftModelRunner(ModelRunner):
assert
seq_group
.
prompt_logprob_indices
==
[]
# No prompt
assert
seq_group
.
prompt_logprob_indices
==
[]
# No prompt
assert
seq_group
.
sample_indices
==
[
i
]
# Simple
assert
seq_group
.
sample_indices
==
[
i
]
# Simple
def
_gpu_advance_step
(
def
_gpu_advance_step
(
self
,
model_input
:
ModelRunnerInputBase
,
self
,
model_input
:
ModelInputForGPUWithSamplingMetadata
,
last_output
:
SamplerOutput
)
->
ModelRunnerInputBase
:
last_output
:
SamplerOutput
)
->
ModelInputForGPUWithSamplingMetadata
:
# Currently, we expect "decode mode" only
# Currently, we expect "decode mode" only
assert
not
model_input
.
is_prompt
assert
not
model_input
.
is_prompt
...
@@ -168,7 +168,7 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -168,7 +168,7 @@ class TP1DraftModelRunner(ModelRunner):
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
self
,
self
,
model_input
:
Model
InputForGPUWithSamplingMetadata
,
model_input
:
Model
RunnerInputBase
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
previous_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
previous_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
...
...
vllm/spec_decode/interfaces.py
View file @
0a71900b
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Set
from
typing
import
Optional
,
Set
,
Union
import
torch
import
torch
...
@@ -75,9 +75,11 @@ class SpeculativeProposer(ABC):
...
@@ -75,9 +75,11 @@ class SpeculativeProposer(ABC):
class
SpeculativeScorer
(
ABC
):
class
SpeculativeScorer
(
ABC
):
def
__init__
(
self
,
scorer_worker
:
WorkerBase
,
device
:
str
,
def
__init__
(
self
,
scorer_worker
:
WorkerBase
,
vocab_size
:
int
):
device
:
Union
[
torch
.
device
,
str
],
vocab_size
:
int
):
self
.
_scorer_worker
=
scorer_worker
self
.
_scorer_worker
=
scorer_worker
if
isinstance
(
device
,
torch
.
device
):
device
=
device
.
type
self
.
_device
=
device
self
.
_device
=
device
self
.
_vocab_size
=
vocab_size
self
.
_vocab_size
=
vocab_size
...
...
vllm/spec_decode/medusa_worker.py
View file @
0a71900b
...
@@ -9,21 +9,22 @@ from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
...
@@ -9,21 +9,22 @@ from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
_base
import
Worker
WrapperBase
class
MedusaWorker
(
NonLLMProposerWorkerBase
,
Worker
):
class
MedusaWorker
(
NonLLMProposerWorkerBase
,
Worker
WrapperBase
):
"""Worker for Medusa.
"""Worker for Medusa.
"""
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
kwargs
.
get
(
"vllm_config"
))
self
.
init_worker
(
*
args
,
**
kwargs
)
# Lazy initialization list.
# Lazy initialization list.
self
.
_proposer
:
Top1Proposer
self
.
_proposer
:
Top1Proposer
def
init_device
(
self
):
def
init_device
(
self
):
s
uper
()
.
init_device
()
s
elf
.
worker
.
init_device
()
self
.
_proposer
=
Top1Proposer
(
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
...
...
vllm/spec_decode/metrics.py
View file @
0a71900b
import
time
import
time
from
typing
import
Callable
,
Optional
from
typing
import
Callable
,
Optional
,
Union
import
msgspec
import
msgspec
import
torch
import
torch
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeBaseSampler
)
SpecDecodeBaseSampler
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
...
@@ -81,8 +82,20 @@ class AsyncMetricsCollector:
...
@@ -81,8 +82,20 @@ class AsyncMetricsCollector:
self
.
_rank
=
rank
self
.
_rank
=
rank
self
.
_copy_stream
=
torch
.
cuda
.
Stream
()
self
.
_copy_stream
=
torch
.
cuda
.
Stream
()
def
init_tensors
(
self
,
rank
:
int
,
device_type
:
Union
[
torch
.
device
,
str
]
=
'cuda'
)
->
None
:
self
.
_rank
=
rank
if
isinstance
(
device_type
,
torch
.
device
):
device_type
=
device_type
.
type
if
device_type
==
'cuda'
:
self
.
_copy_stream
=
torch
.
cuda
.
Stream
()
def
maybe_collect_rejsample_metrics
(
def
maybe_collect_rejsample_metrics
(
self
,
k
:
int
)
->
Optional
[
SpecDecodeWorkerMetrics
]:
self
,
k
:
int
)
->
Optional
[
SpecDecodeWorkerMetrics
]:
# currently using cuda.Event, skip for any non_cuda_alike platform
if
not
current_platform
.
is_cuda_alike
():
return
None
# If a copy was initiated in the previous call, collect and return.
# If a copy was initiated in the previous call, collect and return.
if
self
.
_in_flight_copy
is
not
None
:
if
self
.
_in_flight_copy
is
not
None
:
...
...
vllm/spec_decode/multi_step_worker.py
View file @
0a71900b
...
@@ -5,17 +5,21 @@ from typing import Dict, List, Set, Tuple
...
@@ -5,17 +5,21 @@ from typing import Dict, List, Set, Tuple
import
torch
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
SequenceData
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
SequenceData
,
SequenceGroupMetadata
)
SequenceGroupMetadata
)
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
_base
import
Worker
WrapperBase
class
MultiStepWorker
(
Worker
,
ProposerWorkerBase
):
class
MultiStepWorker
(
ProposerWork
erBase
,
WorkerWrapp
erBase
):
"""The MultiStepWorker is equivalent to a Worker except that it allows
"""The MultiStepWorker is equivalent to a Worker except that it allows
multiple forward passes in a single call, assuming the scheduler has
multiple forward passes in a single call, assuming the scheduler has
allocated enough space to store the additional KV. This reduces overhead
allocated enough space to store the additional KV. This reduces overhead
...
@@ -28,13 +32,14 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
...
@@ -28,13 +32,14 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
"""
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
kwargs
.
get
(
"vllm_config"
))
self
.
init_worker
(
*
args
,
**
kwargs
)
# Lazy initialization list.
# Lazy initialization list.
self
.
_proposer
:
SpeculativeProposer
self
.
_proposer
:
SpeculativeProposer
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
s
uper
()
.
init_device
()
s
elf
.
worker
.
init_device
()
self
.
_proposer
=
Top1Proposer
(
self
.
_proposer
=
Top1Proposer
(
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
weakref
.
proxy
(
self
),
# type: ignore[arg-type]
...
@@ -51,6 +56,18 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
...
@@ -51,6 +56,18 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
self
.
model_runner
.
model
.
sampler
.
should_modify_greedy_probs_inplace
=
(
self
.
model_runner
.
model
.
sampler
.
should_modify_greedy_probs_inplace
=
(
True
)
True
)
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
return
self
.
worker
.
determine_num_available_blocks
()
def
get_cache_block_size_bytes
(
self
)
->
int
:
return
self
.
worker
.
get_cache_block_size_bytes
()
def
initialize_cache
(
self
,
*
args
,
**
kwargs
)
->
None
:
self
.
worker
.
initialize_cache
(
*
args
,
**
kwargs
)
def
execute_model
(
self
,
*
args
,
**
kwargs
)
->
List
[
SamplerOutput
]:
return
self
.
worker
.
execute_model
(
*
args
,
**
kwargs
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
sampler_output
(
def
sampler_output
(
self
,
self
,
...
@@ -75,7 +92,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
...
@@ -75,7 +92,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
# Run model sample_len times.
# Run model sample_len times.
model_outputs
:
List
[
SamplerOutput
]
=
[]
model_outputs
:
List
[
SamplerOutput
]
=
[]
if
isinstance
(
if
current_platform
.
is_cuda_alike
()
and
isinstance
(
self
.
model_runner
,
TP1DraftModelRunner
self
.
model_runner
,
TP1DraftModelRunner
)
and
self
.
model_runner
.
supports_gpu_multi_step
(
expanded_request
):
)
and
self
.
model_runner
.
supports_gpu_multi_step
(
expanded_request
):
# Here we run the draft_model_runner with multi-step prepare
# Here we run the draft_model_runner with multi-step prepare
...
@@ -92,7 +109,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
...
@@ -92,7 +109,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
# and other restrictions that are part of DraftModelRunner's
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
# supports_gpu_multi_step(..)
for
_
in
range
(
sample_len
):
for
_
in
range
(
sample_len
):
model_output
:
List
[
SamplerOutput
]
=
s
uper
()
.
execute_model
(
model_output
:
List
[
SamplerOutput
]
=
s
elf
.
worker
.
execute_model
(
execute_model_req
=
expanded_request
)
execute_model_req
=
expanded_request
)
assert
(
len
(
model_output
)
==
1
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
),
"composing multistep workers not supported"
...
...
vllm/spec_decode/ngram_worker.py
View file @
0a71900b
...
@@ -22,6 +22,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
...
@@ -22,6 +22,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
# Get local_rank/vocab_size from kwargs attribute
# Get local_rank/vocab_size from kwargs attribute
self
.
local_rank
=
kwargs
[
"local_rank"
]
self
.
local_rank
=
kwargs
[
"local_rank"
]
self
.
vocab_size
=
kwargs
[
"vllm_config"
].
model_config
.
get_vocab_size
()
self
.
vocab_size
=
kwargs
[
"vllm_config"
].
model_config
.
get_vocab_size
()
self
.
device_type
=
kwargs
.
get
(
"device_type"
,
"cuda"
)
# Lazy initialization list.
# Lazy initialization list.
self
.
_proposer
:
Top1Proposer
self
.
_proposer
:
Top1Proposer
...
@@ -34,7 +35,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
...
@@ -34,7 +35,7 @@ class NGramWorker(NonLLMProposerWorkerBase):
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
self
.
ngram_prompt_lookup_min
=
ngram_prompt_lookup_min
def
init_device
(
self
):
def
init_device
(
self
):
self
.
device
=
torch
.
device
(
f
"
cuda
:
{
self
.
local_rank
}
"
)
self
.
device
=
torch
.
device
(
f
"
{
self
.
device_type
}
:
{
self
.
local_rank
}
"
)
self
.
load_model
=
lambda
*
args
,
**
kwargs
:
None
self
.
load_model
=
lambda
*
args
,
**
kwargs
:
None
# Current NGramWorker only supports Top1Proposer
# Current NGramWorker only supports Top1Proposer
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
0a71900b
...
@@ -14,12 +14,16 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
...
@@ -14,12 +14,16 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler
,
SpecDecodeStochasticBaseSampler
)
SpecDecodeBaseSampler
,
SpecDecodeStochasticBaseSampler
)
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
from
vllm.model_executor.layers.typical_acceptance_sampler
import
(
TypicalAcceptanceSampler
)
TypicalAcceptanceSampler
)
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
CompletionSequenceGroupOutput
,
ExecuteModelRequest
,
HiddenStates
,
SequenceGroupMetadata
,
HiddenStates
,
SequenceGroupMetadata
,
get_all_seq_ids_and_request_ids
)
get_all_seq_ids_and_request_ids
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
if
current_platform
.
is_cuda_alike
():
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.medusa_worker
import
MedusaWorker
from
vllm.spec_decode.medusa_worker
import
MedusaWorker
...
@@ -36,8 +40,8 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
...
@@ -36,8 +40,8 @@ from vllm.spec_decode.util import (Timer, create_logprobs_output,
get_all_num_logprobs
,
get_all_num_logprobs
,
get_sampled_token_logprobs
,
nvtx_range
,
get_sampled_token_logprobs
,
nvtx_range
,
split_batch_by_proposal_len
)
split_batch_by_proposal_len
)
from
vllm.worker.worker
import
Worker
from
vllm.worker.worker
_base
import
(
LoraNotSupportedWorkerBase
,
WorkerBase
,
from
vllm.worker.worker_base
import
LoraNotSupportedWorkerBase
,
Work
erBase
WorkerWrapp
erBase
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -53,7 +57,11 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
...
@@ -53,7 +57,11 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_kwargs
=
kwargs
.
copy
()
draft_worker_kwargs
=
kwargs
.
copy
()
kwargs
[
"model_runner_cls"
]
=
TargetModelRunner
kwargs
[
"model_runner_cls"
]
=
TargetModelRunner
target_worker
=
Worker
(
*
args
,
**
kwargs
)
target_worker_config
=
copy
.
deepcopy
(
vllm_config
)
target_worker_config
.
parallel_config
.
worker_cls
=
\
target_worker_config
.
parallel_config
.
sd_worker_cls
target_worker
=
WorkerWrapperBase
(
vllm_config
=
target_worker_config
)
target_worker
.
init_worker
(
*
args
,
**
kwargs
)
# Set the disable_logprobs variable in the TargetModelRunner instance
# Set the disable_logprobs variable in the TargetModelRunner instance
# as per its value specified in the SpeculativeConfig.
# as per its value specified in the SpeculativeConfig.
target_worker
.
model_runner
.
disable_logprobs
=
\
target_worker
.
model_runner
.
disable_logprobs
=
\
...
@@ -65,6 +73,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
...
@@ -65,6 +73,8 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_config
.
model_config
,
draft_worker_config
.
model_config
,
vllm_config
.
load_config
,
vllm_config
.
load_config
,
)
)
speculative_config
.
draft_parallel_config
.
worker_cls
=
\
draft_worker_config
.
parallel_config
.
sd_worker_cls
draft_worker_config
.
parallel_config
=
speculative_config
.
draft_parallel_config
# noqa
draft_worker_config
.
parallel_config
=
speculative_config
.
draft_parallel_config
# noqa
# TODO allow draft-model specific load config.
# TODO allow draft-model specific load config.
...
@@ -125,7 +135,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -125,7 +135,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@
classmethod
@
classmethod
def
create_worker
(
def
create_worker
(
cls
,
cls
,
scorer_worker
:
Worker
,
scorer_worker
:
Worker
Base
,
draft_worker_kwargs
:
Dict
[
str
,
Any
],
draft_worker_kwargs
:
Dict
[
str
,
Any
],
disable_mqa_scorer
:
bool
,
disable_mqa_scorer
:
bool
,
disable_by_batch_size
:
Optional
[
int
],
disable_by_batch_size
:
Optional
[
int
],
...
@@ -145,6 +155,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -145,6 +155,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_parallel_config
:
ParallelConfig
=
draft_worker_kwargs
[
draft_parallel_config
:
ParallelConfig
=
draft_worker_kwargs
[
'vllm_config'
].
parallel_config
'vllm_config'
].
parallel_config
if
ngram_prompt_lookup_max
>
0
:
if
ngram_prompt_lookup_max
>
0
:
draft_worker_kwargs
[
"device_type"
]
=
scorer_worker
.
device_config
.
device
.
type
proposer_worker
=
NGramWorker
(
**
draft_worker_kwargs
)
proposer_worker
=
NGramWorker
(
**
draft_worker_kwargs
)
proposer_worker
.
set_ngram_window_size
(
ngram_prompt_lookup_min
,
proposer_worker
.
set_ngram_window_size
(
ngram_prompt_lookup_min
,
ngram_prompt_lookup_max
)
ngram_prompt_lookup_max
)
...
@@ -158,8 +170,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -158,8 +170,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposer_worker
=
MedusaWorker
(
**
draft_worker_kwargs
)
proposer_worker
=
MedusaWorker
(
**
draft_worker_kwargs
)
else
:
else
:
if
draft_tp
==
1
:
if
draft_tp
==
1
:
draft_worker_kwargs
[
if
current_platform
.
is_cuda_alike
():
"model_runner_cls"
]
=
TP1DraftModelRunner
draft_worker_kwargs
[
"model_runner_cls"
]
=
TP1DraftModelRunner
else
:
else
:
if
draft_model_config
.
hf_config
.
model_type
==
"eagle"
:
if
draft_model_config
.
hf_config
.
model_type
==
"eagle"
:
raise
NotImplementedError
(
raise
NotImplementedError
(
...
@@ -306,8 +319,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -306,8 +319,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self
.
scorer_worker
.
load_model
()
self
.
scorer_worker
.
load_model
()
self
.
proposer_worker
.
load_model
()
self
.
proposer_worker
.
load_model
()
self
.
_metrics
.
init_gpu_tensors
(
self
.
rank
)
self
.
_metrics
.
init_tensors
(
self
.
rank
,
device_type
=
self
.
device
)
self
.
spec_decode_sampler
.
init_gpu_tensors
(
self
.
rank
)
self
.
spec_decode_sampler
.
init_tensors
(
self
.
rank
,
device_type
=
self
.
device
)
scorer_cls
:
Type
[
SpeculativeScorer
]
scorer_cls
:
Type
[
SpeculativeScorer
]
if
self
.
disable_mqa_scorer
:
if
self
.
disable_mqa_scorer
:
...
@@ -1111,11 +1125,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -1111,11 +1125,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
raise
NotImplementedError
raise
NotImplementedError
def
start_profile
(
self
):
def
start_profile
(
self
):
if
isinstance
(
self
.
scorer_worker
,
Worker
):
if
isinstance
(
self
.
scorer_worker
,
Worker
Base
):
self
.
scorer_worker
.
start_profile
()
self
.
scorer_worker
.
start_profile
()
def
stop_profile
(
self
):
def
stop_profile
(
self
):
if
isinstance
(
self
.
scorer_worker
,
Worker
):
if
isinstance
(
self
.
scorer_worker
,
Worker
Base
):
self
.
scorer_worker
.
stop_profile
()
self
.
scorer_worker
.
stop_profile
()
...
...
vllm/spec_decode/target_model_runner.py
View file @
0a71900b
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
from
vllm.config
import
VllmConfig
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunner
)
ModelRunnerInputBase
,
ModelRunnerWrapperBase
)
class
TargetModelRunner
(
ModelRunner
):
class
TargetModelRunner
(
ModelRunner
WrapperBase
):
"""Specialized model runner for speculative decoding target model.
"""Specialized model runner for speculative decoding target model.
In speculative decoding, the log probabilities selected finally may not
In speculative decoding, the log probabilities selected finally may not
be the same ones as selected by the target model sampling. This means
be the same ones as selected by the target model sampling. This means
...
@@ -18,32 +18,21 @@ class TargetModelRunner(ModelRunner):
...
@@ -18,32 +18,21 @@ class TargetModelRunner(ModelRunner):
requested or not.
requested or not.
"""
"""
def
__init__
(
def
__init__
(
self
,
model_runner
:
ModelRunnerBase
):
self
,
vllm_config
:
VllmConfig
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
return_hidden_states
:
bool
=
False
,
):
# An internal boolean member variable to indicate if token log
# An internal boolean member variable to indicate if token log
# probabilities are needed or not.
# probabilities are needed or not.
super
().
__init__
(
model_runner
)
self
.
disable_logprobs
=
True
self
.
disable_logprobs
=
True
super
().
__init__
(
vllm_config
=
vllm_config
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
is_driver_worker
,
return_hidden_states
=
return_hidden_states
,
)
def
prepare_model_input
(
def
prepare_model_input
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
virtual_engine
:
int
=
0
,
virtual_engine
:
int
=
0
,
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
finished_requests_ids
:
Optional
[
List
[
str
]]
=
None
,
)
->
Model
InputForGPUWithSamplingMetadata
:
)
->
Model
RunnerInputBase
:
model_input
:
Model
InputForGPUWithSamplingMetadata
=
super
(
model_input
:
Model
RunnerInputBase
=
\
).
prepare_model_input
(
seq_group_metadata_list
,
virtual_engine
,
self
.
model_runner
.
prepare_model_input
(
finished_requests_ids
)
seq_group_metadata_list
,
virtual_engine
,
finished_requests_ids
)
# If token log probabilities is disabled then skip generating sampler
# If token log probabilities is disabled then skip generating sampler
# CPU output. We directly serialize the GPU sampled_token_id tensors
# CPU output. We directly serialize the GPU sampled_token_id tensors
# as needed. If log probabilities is enabled then synchronize all the
# as needed. If log probabilities is enabled then synchronize all the
...
...
vllm/spec_decode/util.py
View file @
0a71900b
...
@@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
...
@@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Sequence, Tuple
import
torch
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
PromptLogprobs
,
SequenceGroupMetadata
,
PromptLogprobs
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceOutput
)
...
@@ -247,11 +248,14 @@ def nvtx_range(msg, *args, **kwargs):
...
@@ -247,11 +248,14 @@ def nvtx_range(msg, *args, **kwargs):
Arguments:
Arguments:
msg (string): message to associate with the range
msg (string): message to associate with the range
"""
"""
torch
.
cuda
.
nvtx
.
range_push
(
msg
.
format
(
*
args
,
**
kwargs
))
if
current_platform
.
is_cuda_alike
():
try
:
torch
.
cuda
.
nvtx
.
range_push
(
msg
.
format
(
*
args
,
**
kwargs
))
try
:
yield
finally
:
torch
.
cuda
.
nvtx
.
range_pop
()
else
:
yield
yield
finally
:
torch
.
cuda
.
nvtx
.
range_pop
()
class
Timer
:
class
Timer
:
...
...
vllm/worker/cpu_model_runner.py
View file @
0a71900b
...
@@ -80,6 +80,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
...
@@ -80,6 +80,7 @@ class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
Used by the ModelRunner.
Used by the ModelRunner.
"""
"""
sampling_metadata
:
Optional
[
"SamplingMetadata"
]
=
None
sampling_metadata
:
Optional
[
"SamplingMetadata"
]
=
None
is_prompt
:
Optional
[
bool
]
=
None
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
def
as_broadcastable_tensor_dict
(
self
)
->
Dict
[
str
,
Any
]:
tensor_dict
=
{
tensor_dict
=
{
...
@@ -395,6 +396,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
...
@@ -395,6 +396,7 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
return_hidden_states
:
bool
=
False
,
*
args
,
*
args
,
**
kwargs
,
**
kwargs
,
):
):
...
@@ -403,19 +405,25 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
...
@@ -403,19 +405,25 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
cache_config
=
self
.
cache_config
cache_config
=
self
.
cache_config
self
.
is_driver_worker
=
is_driver_worker
self
.
is_driver_worker
=
is_driver_worker
self
.
return_hidden_states
=
return_hidden_states
self
.
device
=
self
.
device_config
.
device
self
.
device
=
self
.
device_config
.
device
self
.
pin_memory
=
False
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
block_size
=
cache_config
.
block_size
self
.
block_size
=
cache_config
.
block_size
num_attn_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
)
needs_attn_backend
=
(
num_attn_heads
!=
0
or
self
.
model_config
.
is_attention_free
)
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
dtype
,
self
.
model_config
.
dtype
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
block_size
,
self
.
model_config
.
is_attention_free
,
self
.
model_config
.
is_attention_free
,
)
)
if
needs_attn_backend
else
None
# Multi-modal data support
# Multi-modal data support
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
...
@@ -444,6 +452,15 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
...
@@ -444,6 +452,15 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
return
builder
.
build
()
# type: ignore
return
builder
.
build
()
# type: ignore
# sampler property will be used by spec_decode_worker
@
property
def
sampler
(
self
):
return
self
.
model
.
sampler
@
property
def
vocab_size
(
self
)
->
int
:
return
self
.
model_config
.
get_vocab_size
()
class
CPUModelRunner
(
CPUModelRunnerBase
[
ModelInputForCPUWithSamplingMetadata
]):
class
CPUModelRunner
(
CPUModelRunnerBase
[
ModelInputForCPUWithSamplingMetadata
]):
_model_input_cls
:
Type
[
ModelInputForCPUWithSamplingMetadata
]
=
(
_model_input_cls
:
Type
[
ModelInputForCPUWithSamplingMetadata
]
=
(
...
@@ -480,9 +497,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
...
@@ -480,9 +497,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
pin_memory
=
False
,
pin_memory
=
False
,
generators
=
generators
)
generators
=
generators
)
is_prompt
=
(
seq_group_metadata_list
[
0
].
is_prompt
if
seq_group_metadata_list
else
None
)
return
dataclasses
.
replace
(
model_input
,
return
dataclasses
.
replace
(
model_input
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
virtual_engine
=
virtual_engine
)
virtual_engine
=
virtual_engine
,
is_prompt
=
is_prompt
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
execute_model
(
def
execute_model
(
...
@@ -491,16 +511,22 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
...
@@ -491,16 +511,22 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
num_steps
:
int
=
1
,
num_steps
:
int
=
1
,
previous_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Optional
[
List
[
SamplerOutput
]]:
)
->
Optional
[
List
[
SamplerOutput
]]:
if
num_steps
>
1
:
if
num_steps
>
1
:
raise
ValueError
(
raise
ValueError
(
"CPU worker does not support multi-step execution."
)
"CPU worker does not support multi-step execution."
)
model_executable
=
self
.
model
model_executable
=
self
.
model
multimodal_kwargs
=
{}
multimodal_kwargs
=
{}
if
model_input
.
multi_modal_kwargs
is
not
None
:
if
model_input
.
multi_modal_kwargs
is
not
None
:
multimodal_kwargs
=
MultiModalKwargs
.
as_kwargs
(
multimodal_kwargs
=
MultiModalKwargs
.
as_kwargs
(
model_input
.
multi_modal_kwargs
,
device
=
self
.
device
)
model_input
.
multi_modal_kwargs
,
device
=
self
.
device
)
execute_model_kwargs
=
{}
if
previous_hidden_states
is
not
None
:
execute_model_kwargs
.
update
(
{
"previous_hidden_states"
:
previous_hidden_states
})
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
hidden_states
=
model_executable
(
hidden_states
=
model_executable
(
...
@@ -509,6 +535,7 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
...
@@ -509,6 +535,7 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
kv_caches
=
kv_caches
,
kv_caches
=
kv_caches
,
attn_metadata
=
model_input
.
attn_metadata
,
attn_metadata
=
model_input
.
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
**
execute_model_kwargs
,
**
multimodal_kwargs
,
**
multimodal_kwargs
,
)
)
...
@@ -525,4 +552,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
...
@@ -525,4 +552,12 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
logits
=
logits
,
logits
=
logits
,
sampling_metadata
=
model_input
.
sampling_metadata
,
sampling_metadata
=
model_input
.
sampling_metadata
,
)
)
if
self
.
return_hidden_states
:
# we only need to pass hidden states of most recent token
if
model_input
.
is_prompt
:
output
.
prefill_hidden_states
=
hidden_states
output
.
hidden_states
=
hidden_states
return
[
output
]
return
[
output
]
def
generate_proposals
(
self
,
*
args
,
**
kwargs
):
return
self
.
model
.
generate_proposals
(
*
args
,
**
kwargs
)
vllm/worker/cpu_worker.py
View file @
0a71900b
...
@@ -128,6 +128,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -128,6 +128,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
distributed_init_method
:
str
,
distributed_init_method
:
str
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
kv_cache_dtype
:
Optional
[
str
]
=
"auto"
,
is_driver_worker
:
bool
=
False
,
is_driver_worker
:
bool
=
False
,
model_runner_cls
:
Optional
[
Type
[
CPUModelRunner
]]
=
None
,
)
->
None
:
)
->
None
:
WorkerBase
.
__init__
(
self
,
vllm_config
=
vllm_config
)
WorkerBase
.
__init__
(
self
,
vllm_config
=
vllm_config
)
...
@@ -151,6 +152,16 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -151,6 +152,16 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
else
:
else
:
self
.
local_omp_cpuid
=
omp_cpuids
.
split
(
"|"
)[
rank
]
self
.
local_omp_cpuid
=
omp_cpuids
.
split
(
"|"
)[
rank
]
# Return hidden states from target model if the draft model is an
# mlp_speculator
speculative_config
=
self
.
speculative_config
model_config
=
self
.
model_config
speculative_args
=
{}
if
speculative_config
is
None
\
or
(
speculative_config
.
draft_model_config
.
model
==
model_config
.
model
)
\
or
(
speculative_config
.
draft_model_config
.
hf_config
.
model_type
not
in
[
"medusa"
,
"mlp_speculator"
,
"eagle"
])
\
else
{
"return_hidden_states"
:
True
}
ModelRunnerClass
:
Type
[
CPUModelRunnerBase
]
=
CPUModelRunner
ModelRunnerClass
:
Type
[
CPUModelRunnerBase
]
=
CPUModelRunner
if
self
.
model_config
.
task
==
"embedding"
:
if
self
.
model_config
.
task
==
"embedding"
:
ModelRunnerClass
=
CPUEmbeddingModelRunner
ModelRunnerClass
=
CPUEmbeddingModelRunner
...
@@ -159,7 +170,11 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -159,7 +170,11 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self
.
model_runner
:
CPUModelRunnerBase
=
ModelRunnerClass
(
self
.
model_runner
:
CPUModelRunnerBase
=
ModelRunnerClass
(
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
kv_cache_dtype
=
kv_cache_dtype
,
kv_cache_dtype
=
kv_cache_dtype
,
is_driver_worker
=
is_driver_worker
)
is_driver_worker
=
is_driver_worker
,
**
speculative_args
,
)
if
model_runner_cls
is
not
None
:
self
.
model_runner
=
model_runner_cls
(
self
.
model_runner
)
# Uninitialized cache engine. Will be initialized by
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
# initialize_cache.
self
.
cache_engine
:
List
[
CPUCacheEngine
]
self
.
cache_engine
:
List
[
CPUCacheEngine
]
...
@@ -197,7 +212,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -197,7 +212,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
ret
=
torch
.
ops
.
_C_utils
.
init_cpu_threads_env
(
self
.
local_omp_cpuid
)
ret
=
torch
.
ops
.
_C_utils
.
init_cpu_threads_env
(
self
.
local_omp_cpuid
)
if
ret
:
if
ret
:
logger
.
info
(
ret
)
logger
.
info
(
ret
)
self
.
device
=
torch
.
device
(
"cpu"
)
self
.
init_distributed_environment
()
self
.
init_distributed_environment
()
# Set random seed.
# Set random seed.
set_random_seed
(
self
.
model_config
.
seed
)
set_random_seed
(
self
.
model_config
.
seed
)
...
@@ -297,6 +312,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -297,6 +312,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
def
kv_cache
(
self
)
->
Optional
[
List
[
List
[
torch
.
Tensor
]]]:
return
self
.
cpu_cache
return
self
.
cpu_cache
@
property
def
vocab_size
(
self
)
->
int
:
return
self
.
model_runner
.
vocab_size
@
property
def
max_model_len
(
self
)
->
int
:
return
self
.
model_config
.
max_model_len
def
execute_worker
(
def
execute_worker
(
self
,
self
,
worker_input
:
WorkerInput
,
worker_input
:
WorkerInput
,
...
...
vllm/worker/model_runner_base.py
View file @
0a71900b
...
@@ -289,3 +289,18 @@ class ModelRunnerBase(ABC, Generic[T]):
...
@@ -289,3 +289,18 @@ class ModelRunnerBase(ABC, Generic[T]):
self
.
generators
.
pop
(
request_id
,
None
)
self
.
generators
.
pop
(
request_id
,
None
)
return
self
.
generators
return
self
.
generators
class
ModelRunnerWrapperBase
:
"""
The whole point of this class is to lazily initialize the model_runner.
"""
def
__init__
(
self
,
moderl_runner
:
ModelRunnerBase
,
)
->
None
:
self
.
model_runner
:
ModelRunnerBase
=
moderl_runner
def
__getattr__
(
self
,
attr
):
return
getattr
(
self
.
model_runner
,
attr
)
vllm/worker/worker.py
View file @
0a71900b
...
@@ -74,9 +74,7 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -74,9 +74,7 @@ class Worker(LocalOrDistributedWorkerBase):
else
{
"return_hidden_states"
:
True
}
else
{
"return_hidden_states"
:
True
}
ModelRunnerClass
:
Type
[
GPUModelRunnerBase
]
=
ModelRunner
ModelRunnerClass
:
Type
[
GPUModelRunnerBase
]
=
ModelRunner
if
model_runner_cls
is
not
None
:
if
model_config
.
task
==
"embedding"
:
ModelRunnerClass
=
model_runner_cls
elif
model_config
.
task
==
"embedding"
:
ModelRunnerClass
=
EmbeddingModelRunner
ModelRunnerClass
=
EmbeddingModelRunner
elif
self
.
model_config
.
is_encoder_decoder
:
elif
self
.
model_config
.
is_encoder_decoder
:
ModelRunnerClass
=
EncoderDecoderModelRunner
ModelRunnerClass
=
EncoderDecoderModelRunner
...
@@ -86,6 +84,9 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -86,6 +84,9 @@ class Worker(LocalOrDistributedWorkerBase):
is_driver_worker
=
is_driver_worker
,
is_driver_worker
=
is_driver_worker
,
**
speculative_args
,
**
speculative_args
,
)
)
if
model_runner_cls
is
not
None
:
self
.
model_runner
=
model_runner_cls
(
self
.
model_runner
)
# Uninitialized cache engine. Will be initialized by
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
# initialize_cache.
self
.
cache_engine
:
List
[
CacheEngine
]
self
.
cache_engine
:
List
[
CacheEngine
]
...
...
vllm/worker/worker_base.py
View file @
0a71900b
...
@@ -466,6 +466,9 @@ class WorkerWrapperBase:
...
@@ -466,6 +466,9 @@ class WorkerWrapperBase:
logger
.
exception
(
msg
)
logger
.
exception
(
msg
)
raise
e
raise
e
def
__getattr__
(
self
,
attr
):
return
getattr
(
self
.
worker
,
attr
)
def
extract_previous_hidden_states
(
def
extract_previous_hidden_states
(
data
:
Union
[
ExecuteModelRequest
,
Dict
[
str
,
torch
.
Tensor
]])
->
\
data
:
Union
[
ExecuteModelRequest
,
Dict
[
str
,
torch
.
Tensor
]])
->
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment