Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b2c62023
Unverified
Commit
b2c62023
authored
Jun 28, 2024
by
Cody Yu
Committed by
GitHub
Jun 28, 2024
Browse files
[Spec Decode] Introduce DraftModelRunner (#5799)
parent
b90d8cd8
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
258 additions
and
37 deletions
+258
-37
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+3
-0
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+4
-1
vllm/sequence.py
vllm/sequence.py
+3
-0
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+170
-0
vllm/spec_decode/multi_step_worker.py
vllm/spec_decode/multi_step_worker.py
+16
-13
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+3
-0
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+8
-3
vllm/worker/embedding_model_runner.py
vllm/worker/embedding_model_runner.py
+11
-4
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+7
-3
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+2
-1
vllm/worker/neuron_model_runner.py
vllm/worker/neuron_model_runner.py
+7
-2
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+7
-2
vllm/worker/worker.py
vllm/worker/worker.py
+4
-1
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+5
-4
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+8
-3
No files found.
tests/spec_decode/test_multi_step_worker.py
View file @
b2c62023
...
...
@@ -7,6 +7,7 @@ import torch
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
Logprob
,
SamplerOutput
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.worker.worker
import
Worker
...
...
@@ -85,6 +86,7 @@ def test_same_output_for_single_step():
block_size
,
num_gpu_blocks
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
)
worker
=
create_worker
(
Worker
,
...
...
@@ -168,6 +170,7 @@ def test_same_output_for_multi_step():
block_size
,
num_gpu_blocks
,
seed
,
model_runner_cls
=
TP1DraftModelRunner
,
)
worker
=
create_worker
(
...
...
tests/spec_decode/utils.py
View file @
b2c62023
...
...
@@ -14,6 +14,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput
)
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.model_runner
import
ModelRunner
from
vllm.worker.worker
import
Worker
T
=
TypeVar
(
"T"
,
bound
=
Worker
)
...
...
@@ -66,7 +67,8 @@ def create_worker(cls: Callable[..., T],
num_gpu_blocks
:
int
,
seed
:
int
,
is_driver_worker
:
bool
=
True
,
enforce_eager
:
bool
=
True
)
->
T
:
enforce_eager
:
bool
=
True
,
model_runner_cls
:
Optional
[
ModelRunner
]
=
None
)
->
T
:
engine_args
=
EngineArgs
(
model
=
model_name
,
seed
=
seed
,
...
...
@@ -89,6 +91,7 @@ def create_worker(cls: Callable[..., T],
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
is_driver_worker
=
is_driver_worker
,
model_runner_cls
=
model_runner_cls
,
)
worker
.
init_device
()
...
...
vllm/sequence.py
View file @
b2c62023
...
...
@@ -880,6 +880,8 @@ class ExecuteModelRequest:
running_queue_size
:
int
=
0
# Optional hidden states from prior step.
previous_hidden_states
:
Optional
[
HiddenStates
]
=
None
# The number of forward steps to run.
num_steps
:
int
=
1
def
clone
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
...
...
@@ -893,4 +895,5 @@ class ExecuteModelRequest:
num_lookahead_slots
=
self
.
num_lookahead_slots
,
running_queue_size
=
self
.
running_queue_size
,
previous_hidden_states
=
self
.
previous_hidden_states
,
num_steps
=
self
.
num_steps
,
)
vllm/spec_decode/draft_model_runner.py
0 → 100644
View file @
b2c62023
from
typing
import
List
,
Optional
import
torch
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.logger
import
init_logger
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupMetadata
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
ModelRunner
)
logger
=
init_logger
(
__name__
)
class TP1DraftModelRunner(ModelRunner):
    """Model runner specialized for the speculative decoding draft model.

    In one speculative decoding step the draft model always executes k
    forward passes back to back to produce k speculative tokens. Keeping
    the model input and output tensors on the GPU for the whole step would
    remove most of the CPU-GPU synchronization and data transfer overhead.

    This runner is still under development, so it brings no performance
    gain yet. As a temporary solution it caches the seq_group_metadata_list
    for multi-step execution so that the existing prepare_model_input can
    be reused and the current execution flow stays compatible. The plan is
    to drop this cache and stop calling prepare_model_input from
    execute_model altogether.

    Detailed development plan:

    1. Use "update_model_input" to update the existing model_input instead
       of creating a new one.
    2. Speed up "update_model_input" with a GPU kernel.
    3. Support TP > 1 (this needs some design work because no broadcasting
       is expected inside execute_model).
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        return_hidden_states: bool = False,
    ):
        """Forward all configuration to ModelRunner and create the cache.

        Raises:
            ValueError: if return_hidden_states is requested, which this
                runner does not support.
        """
        if return_hidden_states:
            raise ValueError(
                "return_hidden_states is not supported for TP1DraftModelRunner."
            )

        super().__init__(
            model_config=model_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            cache_config=cache_config,
            load_config=load_config,
            lora_config=lora_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            vision_language_config=vision_language_config,
            return_hidden_states=return_hidden_states,
        )

        # TODO: Remove this cache when we are able to update model_input
        # directly in advance_step.
        self.cached_seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Remember seq_group_metadata_list, then defer to the parent.

        The cached list is what update_model_input mutates between the
        steps of a multi-step execution.

        TODO: In-place update model_input and remove this function.
        """
        self.cached_seq_group_metadata_list = seq_group_metadata_list
        return super().prepare_model_input(seq_group_metadata_list)

    def update_model_input(
            self, model_input: ModelInputForGPUWithSamplingMetadata,
            last_output: SamplerOutput
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Build the model input for the next step from the last output.

        Every sampled token in last_output is appended to its parent
        sequence in the cached metadata, then the model input is rebuilt
        from that cached metadata. The model_input argument itself is not
        read yet.

        TODO: In-place update model_input instead of calling
        prepare_model_input.
        """
        cached_metadata = self.cached_seq_group_metadata_list
        assert cached_metadata is not None

        # Record each sampled token on the sequence it was sampled from.
        for group_metadata, group_output in zip(cached_metadata,
                                                last_output.outputs):
            group_metadata.is_prompt = False

            for sample in group_output.samples:
                seq_data = group_metadata.seq_data[sample.parent_seq_id]

                sampled_token = sample.output_token
                sampled_logprob = sample.logprobs[sampled_token]

                seq_data.append_token_id(sampled_token,
                                         sampled_logprob.logprob)
                seq_data.update_num_computed_tokens(1)

        return self.prepare_model_input(cached_metadata)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run num_steps consecutive forward passes and sample after each.

        Returns:
            One SamplerOutput per executed step, in step order.

        Raises:
            ValueError: if this runner is not the driver worker.
        """
        # Since we do not broadcast data inside execute_model anymore,
        # we need to figure out the best way to support TP > 1 in this
        # case, because we will at least need to broadcast the sampled
        # tokens to all workers.
        if not self.is_driver_worker:
            raise ValueError("TP1DraftModelRunner only supports TP=1.")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        last_step = num_steps - 1
        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
            # Currently cuda graph is only supported by the decode phase.
            assert model_input.attn_metadata is not None
            prefill_meta = model_input.attn_metadata.prefill_metadata
            decode_meta = model_input.attn_metadata.decode_metadata
            if prefill_meta is None and decode_meta.use_cuda_graph:
                assert model_input.input_tokens is not None
                # Captured graphs are looked up by batch size.
                batch_size = model_input.input_tokens.shape[0]
                model_executable = self.graph_runners[batch_size]
            else:
                model_executable = self.model

            extra_kwargs = model_input.multi_modal_kwargs or {}
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                kv_caches=kv_caches,
                attn_metadata=model_input.attn_metadata,
                **extra_kwargs,
            )

            # Compute the logits.
            logits = self.model.compute_logits(hidden_states,
                                               model_input.sampling_metadata)

            # Sample the next token.
            sampled = self.model.sample(
                logits=logits,
                sampling_metadata=model_input.sampling_metadata,
            )
            outputs.append(sampled)

            # Prepare the inputs for the next step, unless this was the
            # final one.
            if step < last_step:
                model_input = self.update_model_input(model_input, sampled)

        return outputs
vllm/spec_decode/multi_step_worker.py
View file @
b2c62023
...
...
@@ -6,6 +6,7 @@ import torch
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeProposer
)
from
vllm.spec_decode.proposer_worker_base
import
ProposerWorkerBase
...
...
@@ -67,22 +68,24 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
copied_execute_model_req
=
execute_model_req
.
clone
(
copied_seq_group_metadata_list
)
# Assert enough KV space for sample_len tokens per sequence.
self
.
_assert_enough_kv_space
(
execute_model_req
.
seq_group_metadata_list
,
sample_len
)
# Run model sample_len times.
model_outputs
:
List
[
SamplerOutput
]
=
[]
for
_
in
range
(
sample_len
):
model_output
:
List
[
SamplerOutput
]
=
super
().
execute_model
(
if
isinstance
(
self
.
model_runner
,
TP1DraftModelRunner
):
copied_execute_model_req
.
num_steps
=
sample_len
model_outputs
=
self
.
execute_model
(
execute_model_req
=
copied_execute_model_req
)
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
model_output
=
model_output
[
0
]
self
.
_append_new_tokens
(
model_output
,
copied_seq_group_metadata_list
)
model_outputs
.
append
(
model_output
)
else
:
# TODO: Remove this branch once DraftModelRunner supports TP>1.
for
_
in
range
(
sample_len
):
model_output
:
List
[
SamplerOutput
]
=
super
().
execute_model
(
execute_model_req
=
copied_execute_model_req
)
assert
(
len
(
model_output
)
==
1
),
"composing multistep workers not supported"
model_output
=
model_output
[
0
]
self
.
_append_new_tokens
(
model_output
,
copied_seq_group_metadata_list
)
model_outputs
.
append
(
model_output
)
return
model_outputs
,
True
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
b2c62023
...
...
@@ -11,6 +11,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates
,
SamplerOutput
,
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.batch_expansion
import
BatchExpansionTop1Scorer
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.metrics
import
AsyncMetricsCollector
...
...
@@ -117,6 +118,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_tp
=
draft_parallel_config
.
tensor_parallel_size
target_tp
=
scorer_worker
.
parallel_config
.
tensor_parallel_size
if
draft_tp
==
1
:
draft_worker_kwargs
[
"model_runner_cls"
]
=
TP1DraftModelRunner
proposer_worker
=
MultiStepWorker
(
**
draft_worker_kwargs
)
proposer_worker
=
SmallerTpProposerWorker
.
maybe_wrap_worker
(
proposer_worker
,
draft_tp
,
target_tp
)
...
...
vllm/worker/cpu_model_runner.py
View file @
b2c62023
...
...
@@ -351,7 +351,12 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self
,
model_input
:
CPUModelInput
,
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
SamplerOutput
]:
num_steps
:
int
=
1
,
)
->
Optional
[
List
[
SamplerOutput
]]:
if
num_steps
>
1
:
raise
ValueError
(
"CPU worker does not support multi-step execution."
)
model_executable
=
self
.
model
execute_model_kwargs
=
{
"input_ids"
:
model_input
.
input_tokens
,
...
...
@@ -371,11 +376,11 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
# Only perform sampling in the driver worker.
if
not
self
.
is_driver_worker
:
return
None
return
[]
# Sample the next token.
output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
model_input
.
sampling_metadata
,
)
return
output
return
[
output
]
vllm/worker/embedding_model_runner.py
View file @
b2c62023
...
...
@@ -57,7 +57,12 @@ class EmbeddingModelRunner(
self
,
model_input
:
ModelInputForGPUWithPoolingMetadata
,
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
PoolerOutput
]:
num_steps
:
int
=
1
,
)
->
Optional
[
List
[
PoolerOutput
]]:
if
num_steps
>
1
:
raise
ValueError
(
"EmbeddingModelRunner does not support multi-step execution."
)
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
...
...
@@ -91,10 +96,12 @@ class EmbeddingModelRunner(
# Only perform pooling in the driver worker.
if
not
self
.
is_driver_worker
:
return
None
return
[]
return
self
.
model
.
pooler
(
hidden_states
=
hidden_states
,
pooling_metadata
=
model_input
.
pooling_metadata
)
return
[
self
.
model
.
pooler
(
hidden_states
=
hidden_states
,
pooling_metadata
=
model_input
.
pooling_metadata
)
]
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
...
...
vllm/worker/model_runner.py
View file @
b2c62023
...
...
@@ -959,7 +959,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self
,
model_input
:
ModelInputForGPUWithSamplingMetadata
,
kv_caches
:
List
[
torch
.
Tensor
],
)
->
SamplerOutput
:
num_steps
:
int
=
1
,
)
->
Optional
[
List
[
SamplerOutput
]]:
if
num_steps
>
1
:
raise
ValueError
(
"num_steps > 1 is not supported in ModelRunner"
)
if
self
.
lora_config
:
assert
model_input
.
lora_requests
is
not
None
assert
model_input
.
lora_mapping
is
not
None
...
...
@@ -992,7 +996,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
# Only perform sampling in the driver worker.
if
not
self
.
is_driver_worker
:
return
None
return
[]
# Sample the next token.
output
:
SamplerOutput
=
self
.
model
.
sample
(
...
...
@@ -1011,7 +1015,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
output
.
hidden_states
=
hidden_states
return
output
return
[
output
]
class
CUDAGraphRunner
:
...
...
vllm/worker/model_runner_base.py
View file @
b2c62023
...
...
@@ -150,7 +150,8 @@ class ModelRunnerBase(ABC, Generic[T]):
self
,
model_input
:
T
,
kv_caches
:
Optional
[
List
[
torch
.
Tensor
]],
)
->
Optional
[
SamplerOutput
]:
num_steps
:
int
=
1
,
)
->
Optional
[
List
[
SamplerOutput
]]:
"""
Execute the model on the given input.
"""
...
...
vllm/worker/neuron_model_runner.py
View file @
b2c62023
...
...
@@ -207,7 +207,12 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self
,
model_input
:
ModelInputForNeuron
,
kv_caches
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
)
->
Optional
[
SamplerOutput
]:
num_steps
:
int
=
1
,
)
->
Optional
[
List
[
SamplerOutput
]]:
if
num_steps
>
1
:
raise
ValueError
(
"NeuronModelRunner does not support multi-step execution."
)
hidden_states
=
self
.
model
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
...
...
@@ -223,7 +228,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
logits
=
logits
,
sampling_metadata
=
model_input
.
sampling_metadata
,
)
return
output
return
[
output
]
@
property
def
vocab_size
(
self
)
->
int
:
...
...
vllm/worker/tpu_model_runner.py
View file @
b2c62023
...
...
@@ -444,7 +444,12 @@ class TPUModelRunner:
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
)
->
SamplerOutput
:
num_steps
:
int
=
1
,
)
->
List
[
SamplerOutput
]:
if
num_steps
>
1
:
raise
ValueError
(
"TPUModelRunner does not support multi-step execution."
)
assert
seq_group_metadata_list
is
not
None
assert
len
(
seq_group_metadata_list
)
>
0
if
seq_group_metadata_list
[
0
].
is_prompt
:
...
...
@@ -462,7 +467,7 @@ class TPUModelRunner:
else
:
sampler_outputs
=
self
.
_execute_model
(
seq_group_metadata_list
,
kv_caches
)
return
SamplerOutput
(
sampler_outputs
)
return
[
SamplerOutput
(
sampler_outputs
)
]
class
ModelWrapper
(
nn
.
Module
):
...
...
vllm/worker/worker.py
View file @
b2c62023
...
...
@@ -45,6 +45,7 @@ class Worker(LocalOrDistributedWorkerBase):
vision_language_config
:
Optional
[
VisionLanguageConfig
]
=
None
,
speculative_config
:
Optional
[
SpeculativeConfig
]
=
None
,
is_driver_worker
:
bool
=
False
,
model_runner_cls
:
Optional
[
Type
[
GPUModelRunnerBase
]]
=
None
,
)
->
None
:
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
...
...
@@ -78,7 +79,9 @@ class Worker(LocalOrDistributedWorkerBase):
"mlp_speculator"
)
else
{
"return_hidden_states"
:
True
}
ModelRunnerClass
:
Type
[
GPUModelRunnerBase
]
=
ModelRunner
if
self
.
model_config
.
embedding_mode
:
if
model_runner_cls
is
not
None
:
ModelRunnerClass
=
model_runner_cls
elif
self
.
model_config
.
embedding_mode
:
ModelRunnerClass
=
EmbeddingModelRunner
self
.
model_runner
:
GPUModelRunnerBase
=
ModelRunnerClass
(
model_config
,
...
...
vllm/worker/worker_base.py
View file @
b2c62023
...
...
@@ -228,11 +228,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
model_input
:
ModelRunnerInputBase
=
(
self
.
model_runner
.
prepare_model_input
(
execute_model_req
.
seq_group_metadata_list
))
num_steps
=
execute_model_req
.
num_steps
if
self
.
do_metadata_broadcast
:
broadcast_data
=
worker_input
.
as_broadcastable_tensor_dict
()
broadcast_data
.
update
(
model_input
.
as_broadcastable_tensor_dict
())
broadcast_data
[
"num_steps"
]
=
num_steps
broadcast_tensor_dict
(
broadcast_data
,
src
=
0
)
else
:
assert
self
.
do_metadata_broadcast
...
...
@@ -240,6 +242,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if
not
broadcast_data
:
return
None
num_steps
=
broadcast_data
.
pop
(
"num_steps"
)
worker_input
=
WorkerInput
.
from_broadcasted_tensor_dict
(
broadcast_data
)
model_input
=
(
...
...
@@ -252,10 +255,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if
worker_input
.
num_seq_groups
==
0
:
return
[]
output
=
self
.
model_runner
.
execute_model
(
model_input
,
self
.
kv_cache
)
# Worker only supports single-step execution. Wrap the output in a
# list to conform to interface.
return
[
output
]
return
self
.
model_runner
.
execute_model
(
model_input
,
self
.
kv_cache
,
num_steps
)
class
WorkerWrapperBase
:
...
...
vllm/worker/xpu_model_runner.py
View file @
b2c62023
...
...
@@ -334,7 +334,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self
,
model_input
:
ModelInputForXPU
,
kv_caches
:
List
[
torch
.
Tensor
],
)
->
Optional
[
SamplerOutput
]:
num_steps
:
int
=
1
,
)
->
Optional
[
List
[
SamplerOutput
]]:
if
num_steps
>
1
:
raise
ValueError
(
"XPUModelRunner does not support multi-step execution."
)
model_executable
=
self
.
model
execute_model_kwargs
=
{
"input_ids"
:
model_input
.
input_tokens
,
...
...
@@ -354,14 +359,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# Only perform sampling in the driver worker.
if
not
self
.
is_driver_worker
:
return
None
return
[]
# Sample the next token.
output
=
self
.
model
.
sample
(
logits
=
logits
,
sampling_metadata
=
model_input
.
sampling_metadata
,
)
return
output
return
[
output
]
def
_prepare_prompt
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment