Unverified Commit b2c62023 authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Spec Decode] Introduce DraftModelRunner (#5799)

parent b90d8cd8
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
...@@ -85,6 +86,7 @@ def test_same_output_for_single_step(): ...@@ -85,6 +86,7 @@ def test_same_output_for_single_step():
block_size, block_size,
num_gpu_blocks, num_gpu_blocks,
seed, seed,
model_runner_cls=TP1DraftModelRunner,
) )
worker = create_worker( worker = create_worker(
Worker, Worker,
...@@ -168,6 +170,7 @@ def test_same_output_for_multi_step(): ...@@ -168,6 +170,7 @@ def test_same_output_for_multi_step():
block_size, block_size,
num_gpu_blocks, num_gpu_blocks,
seed, seed,
model_runner_cls=TP1DraftModelRunner,
) )
worker = create_worker( worker = create_worker(
......
...@@ -14,6 +14,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, ...@@ -14,6 +14,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput) SequenceOutput)
from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
T = TypeVar("T", bound=Worker) T = TypeVar("T", bound=Worker)
...@@ -66,7 +67,8 @@ def create_worker(cls: Callable[..., T], ...@@ -66,7 +67,8 @@ def create_worker(cls: Callable[..., T],
num_gpu_blocks: int, num_gpu_blocks: int,
seed: int, seed: int,
is_driver_worker: bool = True, is_driver_worker: bool = True,
enforce_eager: bool = True) -> T: enforce_eager: bool = True,
model_runner_cls: Optional[ModelRunner] = None) -> T:
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
seed=seed, seed=seed,
...@@ -89,6 +91,7 @@ def create_worker(cls: Callable[..., T], ...@@ -89,6 +91,7 @@ def create_worker(cls: Callable[..., T],
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker, is_driver_worker=is_driver_worker,
model_runner_cls=model_runner_cls,
) )
worker.init_device() worker.init_device()
......
...@@ -880,6 +880,8 @@ class ExecuteModelRequest: ...@@ -880,6 +880,8 @@ class ExecuteModelRequest:
running_queue_size: int = 0 running_queue_size: int = 0
# Optional hidden states from prior step. # Optional hidden states from prior step.
previous_hidden_states: Optional[HiddenStates] = None previous_hidden_states: Optional[HiddenStates] = None
# The number of forward steps to run.
num_steps: int = 1
def clone( def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata] self, seq_group_metadata_list: List[SequenceGroupMetadata]
...@@ -893,4 +895,5 @@ class ExecuteModelRequest: ...@@ -893,4 +895,5 @@ class ExecuteModelRequest:
num_lookahead_slots=self.num_lookahead_slots, num_lookahead_slots=self.num_lookahead_slots,
running_queue_size=self.running_queue_size, running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states, previous_hidden_states=self.previous_hidden_states,
num_steps=self.num_steps,
) )
from typing import List, Optional
import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
logger = init_logger(__name__)
class TP1DraftModelRunner(ModelRunner):
"""Specialized model runner for speculative decoding draft model.
Since the draft model always execute k forward passes consecutively to
generate k speculative tokens in a single speculative decoding step,
we could get rid of most CPU-GPU synchronization and data transfer
overheads by keeping model input and output tensors on GPU all the time.
This runner is still under development so there's no performance gain
at this moment. Currently we adopt a temporary solution that caches the
seq_group_metadata_list for multi-step execution, so that we can
leverage existing prepare_model_input to be compatible with the current
execution flow, but we plan to remove this cache and avoid calling
prepare_model_input in execute_model at all.
The detail development plan includes:
1. Use "update_model_input" to update existing model_input without
creating a new one.
2. Improve the performance of "update_model_input" with a GPU kernel.
3. Support TP > 1 (this requires some designs because we do not expect
any broadcasting inside execute_model).
"""
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
vision_language_config: Optional[VisionLanguageConfig] = None,
return_hidden_states: bool = False,
):
if return_hidden_states:
raise ValueError(
"return_hidden_states is not supported for TP1DraftModelRunner."
)
super().__init__(
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
cache_config=cache_config,
load_config=load_config,
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
vision_language_config=vision_language_config,
return_hidden_states=return_hidden_states,
)
# TODO: Remove this cache when we are able to update model_input
# directly in advance_step.
self.cached_seq_group_metadata_list: Optional[
List[SequenceGroupMetadata]] = None
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInputForGPUWithSamplingMetadata:
"""A temporary solution that caches the seq_group_metadata_list
for multi-step execution.
TODO: In-place update model_input and remove this function.
"""
self.cached_seq_group_metadata_list = seq_group_metadata_list
return super().prepare_model_input(seq_group_metadata_list)
def update_model_input(
self, model_input: ModelInputForGPUWithSamplingMetadata,
last_output: SamplerOutput
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model inputs for the next step.
TODO: In-place update model_input instead of calling
prepare_model_input.
"""
# Append the output token to the sequence data.
assert self.cached_seq_group_metadata_list is not None
for seq_group_metadata, sequence_group_outputs in zip(
self.cached_seq_group_metadata_list, last_output.outputs):
seq_group_metadata.is_prompt = False
for seq_output in sequence_group_outputs.samples:
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]
seq.append_token_id(token_id, token_logprob.logprob)
seq.update_num_computed_tokens(1)
return self.prepare_model_input(self.cached_seq_group_metadata_list)
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if not self.is_driver_worker:
raise ValueError("TP1DraftModelRunner only supports TP=1.")
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
outputs: List[SamplerOutput] = []
for step in range(num_steps):
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
**multi_modal_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Sample the next token.
outputs.append(
self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
))
# Prepare the inputs for the next step.
if step != num_steps - 1:
model_input = self.update_model_input(model_input, outputs[-1])
return outputs
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer) SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
...@@ -67,22 +68,24 @@ class MultiStepWorker(Worker, ProposerWorkerBase): ...@@ -67,22 +68,24 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
copied_execute_model_req = execute_model_req.clone( copied_execute_model_req = execute_model_req.clone(
copied_seq_group_metadata_list) copied_seq_group_metadata_list)
# Assert enough KV space for sample_len tokens per sequence.
self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
sample_len)
# Run model sample_len times. # Run model sample_len times.
model_outputs: List[SamplerOutput] = [] model_outputs: List[SamplerOutput] = []
for _ in range(sample_len): if isinstance(self.model_runner, TP1DraftModelRunner):
model_output: List[SamplerOutput] = super().execute_model( copied_execute_model_req.num_steps = sample_len
model_outputs = self.execute_model(
execute_model_req=copied_execute_model_req) execute_model_req=copied_execute_model_req)
assert (len(model_output) == 1 else:
), "composing multistep workers not supported" # TODO: Remove this branch once DraftModelRunner supports TP>1.
model_output = model_output[0] for _ in range(sample_len):
model_output: List[SamplerOutput] = super().execute_model(
self._append_new_tokens(model_output, execute_model_req=copied_execute_model_req)
copied_seq_group_metadata_list) assert (len(model_output) == 1
model_outputs.append(model_output) ), "composing multistep workers not supported"
model_output = model_output[0]
self._append_new_tokens(model_output,
copied_seq_group_metadata_list)
model_outputs.append(model_output)
return model_outputs, True return model_outputs, True
......
...@@ -11,6 +11,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, ...@@ -11,6 +11,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
HiddenStates, SamplerOutput, SequenceGroupMetadata, HiddenStates, SamplerOutput, SequenceGroupMetadata,
get_all_seq_ids) get_all_seq_ids)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.metrics import AsyncMetricsCollector
...@@ -117,6 +118,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -117,6 +118,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
draft_tp = draft_parallel_config.tensor_parallel_size draft_tp = draft_parallel_config.tensor_parallel_size
target_tp = scorer_worker.parallel_config.tensor_parallel_size target_tp = scorer_worker.parallel_config.tensor_parallel_size
if draft_tp == 1:
draft_worker_kwargs["model_runner_cls"] = TP1DraftModelRunner
proposer_worker = MultiStepWorker(**draft_worker_kwargs) proposer_worker = MultiStepWorker(**draft_worker_kwargs)
proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
proposer_worker, draft_tp, target_tp) proposer_worker, draft_tp, target_tp)
......
...@@ -351,7 +351,12 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -351,7 +351,12 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self, self,
model_input: CPUModelInput, model_input: CPUModelInput,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]: num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")
model_executable = self.model model_executable = self.model
execute_model_kwargs = { execute_model_kwargs = {
"input_ids": model_input.input_tokens, "input_ids": model_input.input_tokens,
...@@ -371,11 +376,11 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -371,11 +376,11 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
# Only perform sampling in the driver worker. # Only perform sampling in the driver worker.
if not self.is_driver_worker: if not self.is_driver_worker:
return None return []
# Sample the next token. # Sample the next token.
output = self.model.sample( output = self.model.sample(
logits=logits, logits=logits,
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
) )
return output return [output]
...@@ -57,7 +57,12 @@ class EmbeddingModelRunner( ...@@ -57,7 +57,12 @@ class EmbeddingModelRunner(
self, self,
model_input: ModelInputForGPUWithPoolingMetadata, model_input: ModelInputForGPUWithPoolingMetadata,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
) -> Optional[PoolerOutput]: num_steps: int = 1,
) -> Optional[List[PoolerOutput]]:
if num_steps > 1:
raise ValueError(
"EmbeddingModelRunner does not support multi-step execution.")
if self.lora_config: if self.lora_config:
assert model_input.lora_requests is not None assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None assert model_input.lora_mapping is not None
...@@ -91,10 +96,12 @@ class EmbeddingModelRunner( ...@@ -91,10 +96,12 @@ class EmbeddingModelRunner(
# Only perform pooling in the driver worker. # Only perform pooling in the driver worker.
if not self.is_driver_worker: if not self.is_driver_worker:
return None return []
return self.model.pooler(hidden_states=hidden_states, return [
pooling_metadata=model_input.pooling_metadata) self.model.pooler(hidden_states=hidden_states,
pooling_metadata=model_input.pooling_metadata)
]
def make_model_input_from_broadcasted_tensor_dict( def make_model_input_from_broadcasted_tensor_dict(
self, self,
......
...@@ -959,7 +959,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -959,7 +959,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self, self,
model_input: ModelInputForGPUWithSamplingMetadata, model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
) -> SamplerOutput: num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")
if self.lora_config: if self.lora_config:
assert model_input.lora_requests is not None assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None assert model_input.lora_mapping is not None
...@@ -992,7 +996,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -992,7 +996,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
# Only perform sampling in the driver worker. # Only perform sampling in the driver worker.
if not self.is_driver_worker: if not self.is_driver_worker:
return None return []
# Sample the next token. # Sample the next token.
output: SamplerOutput = self.model.sample( output: SamplerOutput = self.model.sample(
...@@ -1011,7 +1015,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1011,7 +1015,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
output.hidden_states = hidden_states output.hidden_states = hidden_states
return output return [output]
class CUDAGraphRunner: class CUDAGraphRunner:
......
...@@ -150,7 +150,8 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -150,7 +150,8 @@ class ModelRunnerBase(ABC, Generic[T]):
self, self,
model_input: T, model_input: T,
kv_caches: Optional[List[torch.Tensor]], kv_caches: Optional[List[torch.Tensor]],
) -> Optional[SamplerOutput]: num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
""" """
Execute the model on the given input. Execute the model on the given input.
""" """
......
...@@ -207,7 +207,12 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -207,7 +207,12 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self, self,
model_input: ModelInputForNeuron, model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None, kv_caches: Optional[List[torch.Tensor]] = None,
) -> Optional[SamplerOutput]: num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"NeuronModelRunner does not support multi-step execution.")
hidden_states = self.model( hidden_states = self.model(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
...@@ -223,7 +228,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -223,7 +228,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
logits=logits, logits=logits,
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
) )
return output return [output]
@property @property
def vocab_size(self) -> int: def vocab_size(self) -> int:
......
...@@ -444,7 +444,12 @@ class TPUModelRunner: ...@@ -444,7 +444,12 @@ class TPUModelRunner:
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> SamplerOutput: num_steps: int = 1,
) -> List[SamplerOutput]:
if num_steps > 1:
raise ValueError(
"TPUModelRunner does not support multi-step execution.")
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
if seq_group_metadata_list[0].is_prompt: if seq_group_metadata_list[0].is_prompt:
...@@ -462,7 +467,7 @@ class TPUModelRunner: ...@@ -462,7 +467,7 @@ class TPUModelRunner:
else: else:
sampler_outputs = self._execute_model(seq_group_metadata_list, sampler_outputs = self._execute_model(seq_group_metadata_list,
kv_caches) kv_caches)
return SamplerOutput(sampler_outputs) return [SamplerOutput(sampler_outputs)]
class ModelWrapper(nn.Module): class ModelWrapper(nn.Module):
......
...@@ -45,6 +45,7 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -45,6 +45,7 @@ class Worker(LocalOrDistributedWorkerBase):
vision_language_config: Optional[VisionLanguageConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None, speculative_config: Optional[SpeculativeConfig] = None,
is_driver_worker: bool = False, is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
...@@ -78,7 +79,9 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -78,7 +79,9 @@ class Worker(LocalOrDistributedWorkerBase):
"mlp_speculator") else {"return_hidden_states": True} "mlp_speculator") else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if self.model_config.embedding_mode: if model_runner_cls is not None:
ModelRunnerClass = model_runner_cls
elif self.model_config.embedding_mode:
ModelRunnerClass = EmbeddingModelRunner ModelRunnerClass = EmbeddingModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass( self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model_config, model_config,
......
...@@ -228,11 +228,13 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -228,11 +228,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
model_input: ModelRunnerInputBase = ( model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list)) execute_model_req.seq_group_metadata_list))
num_steps = execute_model_req.num_steps
if self.do_metadata_broadcast: if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict() broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update( broadcast_data.update(
model_input.as_broadcastable_tensor_dict()) model_input.as_broadcastable_tensor_dict())
broadcast_data["num_steps"] = num_steps
broadcast_tensor_dict(broadcast_data, src=0) broadcast_tensor_dict(broadcast_data, src=0)
else: else:
assert self.do_metadata_broadcast assert self.do_metadata_broadcast
...@@ -240,6 +242,7 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -240,6 +242,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if not broadcast_data: if not broadcast_data:
return None return None
num_steps = broadcast_data.pop("num_steps")
worker_input = WorkerInput.from_broadcasted_tensor_dict( worker_input = WorkerInput.from_broadcasted_tensor_dict(
broadcast_data) broadcast_data)
model_input = ( model_input = (
...@@ -252,10 +255,8 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -252,10 +255,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if worker_input.num_seq_groups == 0: if worker_input.num_seq_groups == 0:
return [] return []
output = self.model_runner.execute_model(model_input, self.kv_cache) return self.model_runner.execute_model(model_input, self.kv_cache,
# Worker only supports single-step execution. Wrap the output in a num_steps)
# list to conform to interface.
return [output]
class WorkerWrapperBase: class WorkerWrapperBase:
......
...@@ -334,7 +334,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -334,7 +334,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self, self,
model_input: ModelInputForXPU, model_input: ModelInputForXPU,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]: num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"XPUModelRunner does not support multi-step execution.")
model_executable = self.model model_executable = self.model
execute_model_kwargs = { execute_model_kwargs = {
"input_ids": model_input.input_tokens, "input_ids": model_input.input_tokens,
...@@ -354,14 +359,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): ...@@ -354,14 +359,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# Only perform sampling in the driver worker. # Only perform sampling in the driver worker.
if not self.is_driver_worker: if not self.is_driver_worker:
return None return []
# Sample the next token. # Sample the next token.
output = self.model.sample( output = self.model.sample(
logits=logits, logits=logits,
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
) )
return output return [output]
def _prepare_prompt( def _prepare_prompt(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment