Unverified Commit dda48115 authored by Stephanie Wang's avatar Stephanie Wang Committed by GitHub
Browse files

[Core] Refactor Worker and ModelRunner to consolidate control plane communication (#5408)


Signed-off-by: default avatarStephanie Wang <swang@cs.berkeley.edu>
Signed-off-by: default avatarStephanie <swang@anyscale.com>
Co-authored-by: default avatarStephanie <swang@anyscale.com>
parent 82079729
import dataclasses
from typing import List, Tuple, Type
import torch
from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (
ModelInputForGPUWithPoolingMetadata)
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
class MockAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
raise NotImplementedError
@staticmethod
def get_impl_cls():
raise NotImplementedError
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return AttentionMetadata
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
raise NotImplementedError
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
pass
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
pass
def test_model_runner_input():
sampling_metadata = SamplingMetadata(
["seq_group"],
"selected_token_indices",
"categorized_sample_indices",
"num_prompts",
)
attn_metadata = AttentionMetadata(
num_prefills=1,
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
)
model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
input_positions=torch.ones(10),
sampling_metadata=sampling_metadata,
attn_metadata=attn_metadata)
assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)
# Test round trip serialization.
tensor_dict = model_input.as_broadcastable_tensor_dict()
attn_backend = MockAttentionBackend()
received_model_input = (
ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))
# Check that received copy has correct values.
assert isinstance(received_model_input,
ModelInputForGPUWithSamplingMetadata)
assert received_model_input.input_tokens is not None
assert (
received_model_input.input_tokens == model_input.input_tokens).all()
assert received_model_input.input_positions is not None
assert (received_model_input.input_positions == model_input.input_positions
).all()
assert received_model_input.multi_modal_kwargs is None
assert (received_model_input.multi_modal_kwargs ==
model_input.multi_modal_kwargs)
assert received_model_input.lora_requests is None
assert received_model_input.lora_requests == model_input.lora_requests
assert received_model_input.lora_mapping is None
assert received_model_input.lora_mapping == model_input.lora_mapping
for field in dataclasses.fields(AttentionMetadata):
assert getattr(received_model_input.attn_metadata, field.name,
None) == getattr(attn_metadata, field.name, None)
# For sampling metadata, only selected_token_indices is copied.
assert (received_model_input.sampling_metadata.selected_token_indices ==
sampling_metadata.selected_token_indices)
assert received_model_input.sampling_metadata.seq_groups is None
def test_embedding_model_runner_input():
pooling_metadata = PoolingMetadata(
seq_groups=[[0]],
seq_data={},
prompt_lens=[1],
)
attn_metadata = AttentionMetadata(
num_prefills=1,
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
)
model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10),
input_positions=torch.ones(10),
pooling_metadata=pooling_metadata,
attn_metadata=attn_metadata)
assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)
# Test round trip serialization.
tensor_dict = model_input.as_broadcastable_tensor_dict()
attn_backend = MockAttentionBackend()
received_model_input = (
ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))
# Check that received copy has correct values.
assert isinstance(received_model_input,
ModelInputForGPUWithPoolingMetadata)
assert received_model_input.input_tokens is not None
assert (
received_model_input.input_tokens == model_input.input_tokens).all()
assert received_model_input.input_positions is not None
assert (received_model_input.input_positions == model_input.input_positions
).all()
assert received_model_input.multi_modal_kwargs is None
assert (received_model_input.multi_modal_kwargs ==
model_input.multi_modal_kwargs)
assert received_model_input.lora_requests is None
assert received_model_input.lora_requests == model_input.lora_requests
assert received_model_input.lora_mapping is None
assert received_model_input.lora_mapping == model_input.lora_mapping
for field in dataclasses.fields(AttentionMetadata):
assert getattr(received_model_input.attn_metadata, field.name,
None) == getattr(attn_metadata, field.name, None)
# Pooling metadata is not broadcast.
assert received_model_input.pooling_metadata is None
......@@ -61,12 +61,13 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices.append(selected_token_start_idx +
seq_len - 1)
selected_token_start_idx += seq_len
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = model_input.slot_mapping
slot_mapping = attn_metadata.slot_mapping
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)
......@@ -174,10 +175,11 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens, model_input.input_positions,
model_input.attn_metadata, model_input.slot_mapping)
model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
assert len(slot_mapping) == len(input_tokens)
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
......@@ -259,32 +261,29 @@ def test_empty_seq_group():
enforce_eager=False,
)
seq_group_metadata_list: List[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert input_tokens is None
assert input_positions is None
assert attn_metadata is None
assert len(slot_mapping) == 0
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, slot_mapping,
return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
model_input.seq_lens,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.seq_lens,
)
assert input_tokens is None
assert input_positions is None
assert attn_metadata is None
assert len(slot_mapping) == 0
assert len(return_seq_lens) == 0
assert return_seq_lens is None
@pytest.fixture
......@@ -353,8 +352,12 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
seq_group_metadata_list.append(seq_group_metadata)
decode_metadata_list.append(seq_group_metadata)
(input_tokens, input_positions, attn_metadata, _, _, _,
_) = model_runner.prepare_input_tensors(seq_group_metadata_list)
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
)
prefill_meta_actual = attn_metadata.prefill_metadata
decode_meta_actual = attn_metadata.decode_metadata
......@@ -367,7 +370,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
attn_metadata = model_runner._prepare_model_input(
attn_metadata = model_runner._prepare_model_input_tensors(
seq_group_metadata_list).attn_metadata
for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
......
......@@ -21,9 +21,13 @@ class AttentionBackend(ABC):
@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError
@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
@staticmethod
@abstractmethod
def get_kv_cache_shape(
......
......@@ -90,8 +90,8 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
return BlocksparseFlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata":
return BlocksparseFlashAttentionMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return BlocksparseFlashAttentionMetadata
@staticmethod
def get_kv_cache_shape(
......
......@@ -25,8 +25,8 @@ class FlashAttentionBackend(AttentionBackend):
return FlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
return FlashAttentionMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata
@staticmethod
def get_kv_cache_shape(
......
......@@ -22,8 +22,8 @@ class FlashInferBackend(AttentionBackend):
return FlashInferImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
return FlashInferMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashInferMetadata
@staticmethod
def get_kv_cache_shape(
......
......@@ -25,8 +25,8 @@ class IpexAttnBackend(AttentionBackend):
return IpexAttnBackendImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "IpexAttnMetadata":
return IpexAttnMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["IpexAttnMetadata"]:
return IpexAttnMetadata
@staticmethod
def get_kv_cache_shape(
......
......@@ -16,8 +16,8 @@ class PallasAttentionBackend(AttentionBackend):
return PallasAttentionBackendImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "PallasMetadata":
return PallasMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["PallasMetadata"]:
return PallasMetadata
@staticmethod
def get_kv_cache_shape(
......
......@@ -25,8 +25,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
return ROCmFlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
return ROCmFlashAttentionMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return ROCmFlashAttentionMetadata
@staticmethod
def get_kv_cache_shape(
......
......@@ -31,8 +31,8 @@ class TorchSDPABackend(AttentionBackend):
return TorchSDPABackendImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
return TorchSDPAMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return TorchSDPAMetadata
@staticmethod
def get_kv_cache_shape(
......
......@@ -28,8 +28,8 @@ class XFormersBackend(AttentionBackend):
return XFormersImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "XFormersMetadata":
return XFormersMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return XFormersMetadata
@staticmethod
def get_kv_cache_shape(
......
......@@ -64,8 +64,8 @@ class DistributedGPUExecutor(GPUExecutor):
num_cpu_blocks=num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
......@@ -79,7 +79,7 @@ class DistributedGPUExecutor(GPUExecutor):
if self.parallel_worker_tasks is None:
return
self._driver_execute_model()
self._driver_execute_model(execute_model_req=None)
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
......@@ -123,13 +123,13 @@ class DistributedGPUExecutor(GPUExecutor):
@abstractmethod
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
Passing None will cause the driver to stop the model execution loop
running in each of the remote workers. In this case, this method
returns None. Otherwise, this method returns the model output.
"""
raise NotImplementedError
......
......@@ -69,8 +69,8 @@ class ExecutorBase(ABC):
@abstractmethod
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences."""
raise NotImplementedError
......
......@@ -87,7 +87,7 @@ class GPUExecutor(ExecutorBase):
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> List[Union[SamplerOutput, PoolerOutput]]:
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
output = self.driver_worker.execute_model(execute_model_req)
return output
......
......@@ -78,16 +78,14 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
worker_monitor.close()
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return self.driver_worker.execute_model(
execute_model_req=execute_model_req)
return self.driver_worker.execute_model(execute_model_req)
def _run_workers(
self,
......
......@@ -55,8 +55,7 @@ class NeuronExecutor(ExecutorBase):
assert execute_model_req.num_lookahead_slots == 0, (
"lookahead not supported for Neuron backend.")
output = self.driver_worker.execute_model(
execute_model_req.seq_group_metadata_list)
output = self.driver_worker.execute_model(execute_model_req)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:
......
......@@ -190,9 +190,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
max_parallel_loading_workers)
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
......
......@@ -887,7 +887,8 @@ class HiddenStates:
@dataclass
class ExecuteModelRequest:
"""The model execution request."""
"""The model execution request, containing CPU metadata only. The LLM
engine should create an instance of this class for each request batch."""
# The sequence group metadata list.
seq_group_metadata_list: List[SequenceGroupMetadata]
# Blocks to swap in. List of CPU -> GPU block number.
......
......@@ -7,7 +7,6 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.worker.model_runner import ModelInput
class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
......@@ -56,7 +55,7 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, List[int], List[int]]:
if not seq_group_metadata_list:
return ModelInput.empty(self.device)
return torch.empty(0, device=self.device), [], []
input_tokens: List[int] = []
seq_lens: List[int] = []
......
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import torch
from torch import nn
......@@ -8,20 +9,64 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
class CPUModelRunner:
@dataclass(frozen=True)
class CPUModelInput(ModelRunnerInputBase):
"""
Used by the CPUModelRunner.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["CPUModelInput"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None
) -> "CPUModelInput":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
def __init__(
self,
......@@ -270,86 +315,70 @@ class CPUModelRunner:
attn_metadata,
)
def prepare_input_tensors(
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> CPUModelInput:
return CPUModelInput.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Optional[Dict[str, torch.Tensor]]]:
) -> CPUModelInput:
multi_modal_kwargs = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
seq_lens = []
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
# Broadcast the metadata.
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
"selected_token_indices":
sampling_metadata.selected_token_indices,
}
metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
selected_token_indices = metadata_dict.pop(
"selected_token_indices")
attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
seq_lens=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
generators=None,
)
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, multi_modal_kwargs)
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
seq_lens = []
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
return CPUModelInput(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
sampling_metadata=sampling_metadata,
)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: CPUModelInput,
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata,
multi_modal_input
) = self.prepare_input_tensors(seq_group_metadata_list)
model_executable = self.model
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"input_ids": model_input.input_tokens,
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
"attn_metadata": model_input.attn_metadata,
}
if self.vision_language_config and multi_modal_input is not None:
execute_model_kwargs.update(multi_modal_input)
if (self.vision_language_config
and model_input.multi_modal_kwargs is not None):
execute_model_kwargs.update(model_input.multi_modal_kwargs)
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
......@@ -358,6 +387,6 @@ class CPUModelRunner:
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
sampling_metadata=model_input.sampling_metadata,
)
return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment