Unverified Commit dda48115 authored by Stephanie Wang's avatar Stephanie Wang Committed by GitHub
Browse files

[Core] Refactor Worker and ModelRunner to consolidate control plane communication (#5408)


Signed-off-by: default avatarStephanie Wang <swang@cs.berkeley.edu>
Signed-off-by: default avatarStephanie <swang@anyscale.com>
Co-authored-by: default avatarStephanie <swang@anyscale.com>
parent 82079729
"""A CPU worker class.""" """A CPU worker class."""
from typing import Any, Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.distributed import torch.distributed
...@@ -8,15 +8,15 @@ from vllm.attention import get_attn_backend ...@@ -8,15 +8,15 @@ from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict, from vllm.distributed import (ensure_model_parallel_initialized,
ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -110,7 +110,7 @@ class CPUCacheEngine: ...@@ -110,7 +110,7 @@ class CPUCacheEngine:
return dtype_size * total return dtype_size * total
class CPUWorker(LoraNotSupportedWorkerBase): class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a CPU socket. """A worker class that executes (a partition of) the model on a CPU socket.
Each worker is associated with a single CPU socket. The worker is Each worker is associated with a single CPU socket. The worker is
...@@ -154,7 +154,7 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -154,7 +154,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
init_cached_hf_modules() init_cached_hf_modules()
self.model_runner = CPUModelRunner( self.model_runner: CPUModelRunner = CPUModelRunner(
model_config, model_config,
parallel_config, parallel_config,
scheduler_config, scheduler_config,
...@@ -255,54 +255,37 @@ class CPUWorker(LoraNotSupportedWorkerBase): ...@@ -255,54 +255,37 @@ class CPUWorker(LoraNotSupportedWorkerBase):
for layer_cache in self.cpu_cache: for layer_cache in self.cpu_cache:
layer_cache.fill_(0) layer_cache.fill_(0)
def cache_copy( @property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[torch.Tensor]]:
return self.cpu_cache
def execute_worker(
self, self,
blocks_to_copy: torch.Tensor, worker_input: WorkerInput,
) -> None: ) -> None:
if blocks_to_copy.numel() > 0: if (worker_input.blocks_to_copy is not None
self.cache_engine.copy(blocks_to_copy) and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine.copy(worker_input.blocks_to_copy)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def prepare_worker_input(
self, self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
execute_model_req: Optional[ExecuteModelRequest] = None, assert execute_model_req is not None
) -> List[SamplerOutput]: num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
blocks_to_copy = execute_model_req.blocks_to_copy
if execute_model_req is None: blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
seq_group_metadata_list = None device="cpu",
else: dtype=torch.int64).view(-1, 2)
seq_group_metadata_list = execute_model_req.seq_group_metadata_list assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(execute_model_req.blocks_to_swap_out) == 0
if self.is_driver_worker: return WorkerInput(
assert seq_group_metadata_list is not None num_seq_groups=num_seq_groups,
num_seq_groups: int = len(seq_group_metadata_list) blocks_to_copy=blocks_to_copy,
assert execute_model_req is not None )
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device="cpu",
dtype=torch.int64).view(-1, 2)
assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(execute_model_req.blocks_to_swap_out) == 0
data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_copy": execute_model_req.blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else:
data = broadcast_tensor_dict(src=0)
num_seq_groups = data["num_seq_groups"]
blocks_to_copy = data["blocks_to_copy"]
self.cache_copy(blocks_to_copy)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return []
output = self.model_runner.execute_model(seq_group_metadata_list,
self.cpu_cache)
# CPU worker only supports single-step execution.
return [output]
def init_distributed_environment(self) -> None: def init_distributed_environment(self) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
......
from typing import Dict, List, Optional, Set, Tuple import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type
import torch import torch
from vllm.attention import AttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
logger = init_logger(__name__) logger = init_logger(__name__)
class EmbeddingModelRunner(ModelRunner): @dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
"""
Used by the EmbeddingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None
class EmbeddingModelRunner(
GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
ModelInputForGPUWithPoolingMetadata)
def __init__( def __init__(
self, self,
...@@ -47,21 +55,22 @@ class EmbeddingModelRunner(ModelRunner): ...@@ -47,21 +55,22 @@ class EmbeddingModelRunner(ModelRunner):
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], model_input: ModelInputForGPUWithPoolingMetadata,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
(input_tokens, input_positions, attn_metadata, pooling_metadata,
lora_requests, lora_mapping, multi_modal_input
) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config: if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping) assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
# Currently cuda graph is only supported by the decode phase. # Currently cuda graph is only supported by the decode phase.
prefill_meta = attn_metadata.prefill_metadata assert model_input.attn_metadata is not None
decode_meta = attn_metadata.decode_metadata prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph: if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0] assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size] model_executable = self.graph_runners[graph_batch_size]
else: else:
model_executable = self.model model_executable = self.model
...@@ -70,13 +79,14 @@ class EmbeddingModelRunner(ModelRunner): ...@@ -70,13 +79,14 @@ class EmbeddingModelRunner(ModelRunner):
kv_caches = [None] * num_layers kv_caches = [None] * num_layers
execute_model_kwargs = { execute_model_kwargs = {
"input_ids": input_tokens, "input_ids": model_input.input_tokens,
"positions": input_positions, "positions": model_input.input_positions,
"kv_caches": kv_caches, "kv_caches": kv_caches,
"attn_metadata": attn_metadata, "attn_metadata": model_input.attn_metadata,
} }
if self.vision_language_config: if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input}) multi_modal_kwargs = model_input.multi_modal_kwargs or {}
execute_model_kwargs.update({"image_input": multi_modal_kwargs})
hidden_states = model_executable(**execute_model_kwargs) hidden_states = model_executable(**execute_model_kwargs)
# Only perform pooling in the driver worker. # Only perform pooling in the driver worker.
...@@ -84,66 +94,31 @@ class EmbeddingModelRunner(ModelRunner): ...@@ -84,66 +94,31 @@ class EmbeddingModelRunner(ModelRunner):
return None return None
return self.model.pooler(hidden_states=hidden_states, return self.model.pooler(hidden_states=hidden_states,
pooling_metadata=pooling_metadata) pooling_metadata=model_input.pooling_metadata)
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str,
Any]) -> ModelInputForGPUWithPoolingMetadata:
return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_input_tensors( def prepare_model_input(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata, ) -> ModelInputForGPUWithPoolingMetadata:
Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]: assert seq_group_metadata_list is not None
if self.is_driver_worker: model_input = self._prepare_model_input_tensors(
assert seq_group_metadata_list is not None seq_group_metadata_list)
# Prepare input tensors. # Prepare PoolingMetadata.
( assert model_input.seq_lens is not None
input_tokens, pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
input_positions, model_input.seq_lens)
attn_metadata,
seq_lens, return dataclasses.replace(model_input,
_, pooling_metadata=pooling_metadata)
lora_mapping,
lora_requests,
multi_modal_kwargs,
slot_mapping,
num_prefill_tokens,
num_decode_tokens,
num_prefills,
) = self._prepare_model_input(seq_group_metadata_list)
# Prepare PoolingMetadata
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
seq_lens)
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
"multi_modal_kwargs": multi_modal_kwargs,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
}
if attn_metadata:
metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
if metadata_dict:
attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
attn_metadata = None
pooling_metadata = PoolingMetadata(seq_groups=None,
seq_data=None,
prompt_lens=None)
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
lora_requests, lora_mapping, multi_modal_kwargs)
def _prepare_pooling( def _prepare_pooling(
self, self,
......
import dataclasses
import gc import gc
import time import time
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
TypeVar, Union)
import numpy as np import numpy as np
import torch import torch
...@@ -12,7 +14,6 @@ from vllm.attention import AttentionMetadata, get_attn_backend ...@@ -12,7 +14,6 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
...@@ -26,6 +27,15 @@ from vllm.sampling_params import SamplingParams ...@@ -26,6 +27,15 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad) is_pin_memory_available, make_tensor_with_pad)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -39,40 +49,90 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ ...@@ -39,40 +49,90 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
] ]
_NUM_WARMUP_ITERS = 2 _NUM_WARMUP_ITERS = 2
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
class ModelInput(NamedTuple):
input_tokens: torch.Tensor
input_positions: torch.Tensor
attn_metadata: Optional[AttentionMetadata]
seq_lens: List[int]
query_lens: List[int]
lora_mapping: Optional[LoRAMapping]
lora_requests: Set[LoRARequest]
multi_modal_kwargs: Dict[str, torch.Tensor]
slot_mapping: torch.Tensor
num_prefill_tokens: int
num_decode_tokens: int
num_prefills: int
@classmethod @dataclasses.dataclass(frozen=True)
def empty(cls, device): class ModelInputForGPU(ModelRunnerInputBase):
return ModelInput( """
input_tokens=torch.empty(0, device=device), This base class contains metadata needed for the base model forward pass
input_positions=torch.empty(0, device=device), but not metadata for possible additional steps, e.g., sampling. Model
attn_metadata=None, runners that run additional steps should subclass this method to add
seq_lens=[], additional fields.
query_lens=[], """
lora_mapping=None, input_tokens: Optional[torch.Tensor] = None
lora_requests=set(), input_positions: Optional[torch.Tensor] = None
multi_modal_kwargs={}, seq_lens: Optional[List[int]] = None
slot_mapping=torch.empty(0, device=device), query_lens: Optional[List[int]] = None
num_prefill_tokens=0, lora_mapping: Optional["LoRAMapping"] = None
num_decode_tokens=0, lora_requests: Optional[Set[LoRARequest]] = None
num_prefills=0, attn_metadata: Optional["AttentionMetadata"] = None
) multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type[TModelInputForGPU],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> TModelInputForGPU:
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"""
Used by the ModelRunner.
"""
sampling_metadata: Optional["SamplingMetadata"] = None
# Used for speculative decoding. We do not broadcast it because it is only
# used by the driver worker.
is_prompt: Optional[bool] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
class ModelRunner: @classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForGPUWithSamplingMetadata":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"""
Helper class for shared methods between GPU model runners.
"""
_model_input_cls: Type[TModelInputForGPU]
def __init__( def __init__(
self, self,
...@@ -241,11 +301,13 @@ class ModelRunner: ...@@ -241,11 +301,13 @@ class ModelRunner:
block_size = self.block_size block_size = self.block_size
return (self.max_seq_len_to_capture + block_size - 1) // block_size return (self.max_seq_len_to_capture + block_size - 1) // block_size
def _prepare_model_input( def _prepare_model_input_tensors(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInput: ) -> TModelInputForGPU:
"""Prepare the model input based on a given sequence group. """Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
The API assumes seq_group_metadata_list is sorted by prefill -> decode. The API assumes seq_group_metadata_list is sorted by prefill -> decode.
...@@ -296,7 +358,7 @@ class ModelRunner: ...@@ -296,7 +358,7 @@ class ModelRunner:
paged_kv_last_page_len: List[int] = [] paged_kv_last_page_len: List[int] = []
if len(seq_group_metadata_list) == 0: if len(seq_group_metadata_list) == 0:
return ModelInput.empty(self.device) return self._model_input_cls()
if self.sliding_window is not None: if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window + self.block_size - sliding_window_blocks = (self.sliding_window + self.block_size -
...@@ -646,7 +708,7 @@ class ModelRunner: ...@@ -646,7 +708,7 @@ class ModelRunner:
for k, v in multi_modal_kwargs_list.items() for k, v in multi_modal_kwargs_list.items()
} }
return ModelInput( return self._model_input_cls(
input_tokens=input_tokens_tensor, input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor, input_positions=input_positions_tensor,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
...@@ -655,132 +717,8 @@ class ModelRunner: ...@@ -655,132 +717,8 @@ class ModelRunner:
lora_mapping=lora_mapping, lora_mapping=lora_mapping,
lora_requests=lora_requests, lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs, multi_modal_kwargs=multi_modal_kwargs,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
)
def prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
# Prepare input tensors.
(
input_tokens,
input_positions,
attn_metadata,
seq_lens,
query_lens,
lora_mapping,
lora_requests,
multi_modal_kwargs,
slot_mapping,
num_prefill_tokens,
num_decode_tokens,
num_prefills,
) = self._prepare_model_input(seq_group_metadata_list)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.pin_memory)
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
"selected_token_indices":
sampling_metadata.selected_token_indices,
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
"multi_modal_kwargs": multi_modal_kwargs,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
}
if attn_metadata:
metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
selected_token_indices = metadata_dict.pop(
"selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
if metadata_dict:
attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
attn_metadata = None
sampling_metadata = SamplingMetadata(
seq_groups=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
num_prompts=0,
)
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping,
multi_modal_kwargs)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata,
lora_requests, lora_mapping, multi_modal_kwargs
) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
# Currently cuda graph is only supported by the decode phase.
prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
model_executable = self.model
hidden_states = model_executable(
input_ids=input_tokens,
positions=input_positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
**multi_modal_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return None
# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
) )
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
assert seq_group_metadata_list is not None
if seq_group_metadata_list[0].is_prompt:
hidden_states = hidden_states.index_select(
0, sampling_metadata.selected_token_indices)
output.hidden_states = hidden_states
return output
@torch.inference_mode() @torch.inference_mode()
def profile_run(self) -> None: def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage. # Enable top-k sampling to reflect the accurate memory usage.
...@@ -853,7 +791,8 @@ class ModelRunner: ...@@ -853,7 +791,8 @@ class ModelRunner:
# Run the model with the dummy inputs. # Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config) num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers kv_caches = [None] * num_layers
self.execute_model(seqs, kv_caches) model_input = self.prepare_model_input(seqs)
self.execute_model(model_input, kv_caches)
torch.cuda.synchronize() torch.cuda.synchronize()
return return
...@@ -986,6 +925,110 @@ class ModelRunner: ...@@ -986,6 +925,110 @@ class ModelRunner:
return self.model_config.get_vocab_size() return self.model_config.get_vocab_size()
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"""
GPU model runner with sampling step.
"""
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata)
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> ModelInputForGPUWithSamplingMetadata:
return (
ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
The result tensors and data structure also batches input in prefill
-> decode order. For example,
- input_tokens[:num_prefill_tokens] contains prefill tokens.
- input_tokens[num_prefill_tokens:] contains decode tokens.
If cuda graph is required, this API automatically pads inputs.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
self.pin_memory)
is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
is_prompt=is_prompt)
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
) -> SamplerOutput:
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
**multi_modal_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return None
# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
if model_input.is_prompt:
assert model_input.sampling_metadata is not None
hidden_states = hidden_states.index_select(
0, model_input.sampling_metadata.selected_token_indices)
output.hidden_states = hidden_states
return output
class CUDAGraphRunner: class CUDAGraphRunner:
def __init__(self, model: nn.Module): def __init__(self, model: nn.Module):
......
import dataclasses
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar)
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
T = TypeVar('T', bound="ModelRunnerInputBase")
def _add_attn_metadata_broadcastable_dict(
tensor_dict: Dict[str, Any],
attn_metadata: Optional["AttentionMetadata"]) -> None:
"""
Helper method to update tensor_dict with broadcastable
AttentionMetadata fields.
"""
if attn_metadata is not None:
tensor_dict.update(attn_metadata.asdict_zerocopy())
def _init_attn_metadata_from_tensor_dict(
attn_backend: "AttentionBackend",
tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
"""
Helper method to initialize AttentionMetadata based on an
AttentionBackend and broadcastable AttentionMetadata fields.
"""
# Extract the fields used to create AttentionMetadata.
valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()):
val = tensor_dict.pop(field.name, None)
if val is not None:
valid_attn_kwargs[field.name] = val
attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
tensor_dict["attn_metadata"] = attn_metadata
return tensor_dict
def _init_sampling_metadata_from_tensor_dict( # type: ignore
tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Helper method to initialize SamplingMetadata based on broadcastable
SamplingMetadata fields.
"""
from vllm.model_executor import SamplingMetadata
selected_token_indices = tensor_dict.pop("selected_token_indices", None)
# An empty SamplingMetadata to signal that the worker should skip
# sampling.
if selected_token_indices is not None:
tensor_dict["sampling_metadata"] = SamplingMetadata(
seq_groups=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
num_prompts=0,
)
return tensor_dict
def _add_sampling_metadata_broadcastable_dict(
tensor_dict: Dict[str, Any],
sampling_metadata: Optional["SamplingMetadata"]) -> None:
"""
Helper method to update tensor_dict with broadcastable
SamplingMetadata fields.
"""
if sampling_metadata is not None:
tensor_dict["selected_token_indices"] = (
sampling_metadata.selected_token_indices)
@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(ABC):
"""Local inputs to each worker's model runner. May contain
device-specific data. Different worker backends may have different methods
of converting from the global ExecuteModelRequest produced by the LLM
engine to the worker-local ModelRunnerInputBase objects.
Model runners that support multi-GPU execution should define a
ModelRunnerInputBase subclass, add their required fields, and specify how to
serialize/deserialize a ModelInput for broadcast between workers.
"""
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"""
Extract broadcastable fields. Override for fields that require some
custom deserialization.
"""
raise NotImplementedError
@classmethod
@abstractmethod
def from_broadcasted_tensor_dict(
cls: Type[T],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> T:
"""
Pop fields from the given tensor_dict and populate a new instance of
ModelRunnerInputBase.
"""
raise NotImplementedError
class ModelRunnerBase(ABC, Generic[T]):
"""
Model runner interface that abstracts a particular hardware and/or type of
model. Model execution may communicate data with model runners in other
processes, but it should not include control plane metadata communication.
Each ModelRunnerBase subclass should define a corresponding
ModelRunnerInputBase subclass.
"""
@abstractmethod
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> T:
"""
Make an instance of a ModelRunnerInputBase from the broadcasted tensor
dict.
"""
raise NotImplementedError
@abstractmethod
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> T:
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
request. This method may move data to the worker's local device. It is
not allowed to communicate with other workers or devices.
"""
raise NotImplementedError
@torch.inference_mode()
def execute_model(
self,
model_input: T,
kv_caches: Optional[List[torch.Tensor]],
) -> Optional[SamplerOutput]:
"""
Execute the model on the given input.
"""
raise NotImplementedError
from typing import List, Optional, Tuple from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -10,11 +11,39 @@ from vllm.model_executor import SamplingMetadata ...@@ -10,11 +11,39 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__) logger = init_logger(__name__)
class NeuronModelRunner: @dataclass(frozen=True)
class ModelInputForNeuron(ModelRunnerInputBase):
"""
Used by the NeuronModelRunner.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
input_block_ids: Optional[torch.Tensor] = None
sampling_metadata: Optional["SamplingMetadata"] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
raise NotImplementedError("ModelInputForNeuron cannot be broadcast.")
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForNeuron":
assert attn_backend is None
return cls.from_broadcasted_tensor_dict(tensor_dict)
class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
def __init__( def __init__(
self, self,
...@@ -139,10 +168,14 @@ class NeuronModelRunner: ...@@ -139,10 +168,14 @@ class NeuronModelRunner:
return input_tokens, input_positions, input_block_ids return input_tokens, input_positions, input_block_ids
def prepare_input_tensors( def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)
def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]: ) -> ModelInputForNeuron:
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
# all decodes. # all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt is_prompt = seq_group_metadata_list[0].is_prompt
...@@ -164,30 +197,31 @@ class NeuronModelRunner: ...@@ -164,30 +197,31 @@ class NeuronModelRunner:
self.device, self.device,
self.pin_memory) self.pin_memory)
return (input_tokens, input_positions, input_block_ids, return ModelInputForNeuron(input_tokens=input_tokens,
sampling_metadata) input_positions=input_positions,
input_block_ids=input_block_ids,
sampling_metadata=sampling_metadata)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
(input_tokens, input_positions, input_block_ids, sampling_metadata
) = self.prepare_input_tensors(seq_group_metadata_list)
hidden_states = self.model( hidden_states = self.model(
input_ids=input_tokens, input_ids=model_input.input_tokens,
positions=input_positions, positions=model_input.input_positions,
input_block_ids=input_block_ids, input_block_ids=model_input.input_block_ids,
) )
# Compute the logits. # Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata) logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Sample the next token. # Sample the next token.
output = self.model.sample( output = self.model.sample(
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=model_input.sampling_metadata,
) )
return output return output
......
"""A Neuron worker class.""" """A Neuron worker class."""
from typing import List, Tuple from typing import List, Optional, Tuple
import torch import torch
import torch.distributed import torch.distributed
...@@ -7,12 +7,13 @@ import torch.distributed ...@@ -7,12 +7,13 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
class NeuronWorker(LoraNotSupportedWorkerBase): class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"""A worker class that executes the model on a group of neuron cores. """A worker class that executes the model on a group of neuron cores.
""" """
...@@ -34,8 +35,9 @@ class NeuronWorker(LoraNotSupportedWorkerBase): ...@@ -34,8 +35,9 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
init_cached_hf_modules() init_cached_hf_modules()
self.model_runner = NeuronModelRunner(model_config, parallel_config, self.model_runner: NeuronModelRunner = NeuronModelRunner(
scheduler_config, device_config) model_config, parallel_config, scheduler_config, device_config)
self.is_driver_worker = True
def init_device(self) -> None: def init_device(self) -> None:
# Set random seed. # Set random seed.
...@@ -73,22 +75,19 @@ class NeuronWorker(LoraNotSupportedWorkerBase): ...@@ -73,22 +75,19 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks
@torch.inference_mode() @property
def execute_model( def do_metadata_broadcast(self) -> bool:
self, return False
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[SamplerOutput]:
num_seq_groups = len(seq_group_metadata_list)
# If there is no input, we don't need to execute the model. @property
if num_seq_groups == 0: def kv_cache(self) -> Optional[List[torch.Tensor]]:
return [] return None
output = self.model_runner.execute_model(seq_group_metadata_list) @torch.inference_mode()
def prepare_worker_input(
# Neuron worker only supports single-step output. Wrap the output in a self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
# list to conform to interface. return WorkerInput(num_seq_groups=len(
return [output] execute_model_req.seq_group_metadata_list), )
def get_cache_block_size_bytes(self) -> int: def get_cache_block_size_bytes(self) -> int:
"""Determine the size in bytes of a cache block. """Determine the size in bytes of a cache block.
......
"""A GPU worker class.""" """A GPU worker class."""
import gc import gc
import os import os
from typing import Any, Dict, List, Optional, Set, Tuple, Union from typing import List, Optional, Set, Tuple, Type
import torch import torch
import torch.distributed import torch.distributed
...@@ -9,21 +9,20 @@ import torch.distributed ...@@ -9,21 +9,20 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig) SpeculativeConfig, VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict, from vllm.distributed import (ensure_model_parallel_initialized,
ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce) set_custom_all_reduce)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.sequence import ExecuteModelRequest
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
class Worker(WorkerBase): class Worker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a GPU. """A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for Each worker is associated with a single GPU. The worker is responsible for
...@@ -78,9 +77,10 @@ class Worker(WorkerBase): ...@@ -78,9 +77,10 @@ class Worker(WorkerBase):
or (speculative_config.draft_model_config.hf_config.model_type != or (speculative_config.draft_model_config.hf_config.model_type !=
"mlp_speculator") else {"return_hidden_states": True} "mlp_speculator") else {"return_hidden_states": True}
ModelRunnerClass = (EmbeddingModelRunner if ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
self.model_config.embedding_mode else ModelRunner) if self.model_config.embedding_mode:
self.model_runner = ModelRunnerClass( ModelRunnerClass = EmbeddingModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model_config, model_config,
parallel_config, parallel_config,
scheduler_config, scheduler_config,
...@@ -225,40 +225,18 @@ class Worker(WorkerBase): ...@@ -225,40 +225,18 @@ class Worker(WorkerBase):
# the model initialization and profiling. # the model initialization and profiling.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
def cache_swap( @property
self, def do_metadata_broadcast(self) -> bool:
blocks_to_swap_in: torch.Tensor, return self.parallel_config.tensor_parallel_size > 1
blocks_to_swap_out: torch.Tensor,
blocks_to_copy: torch.Tensor, @property
) -> None: def kv_cache(self) -> Optional[List[torch.Tensor]]:
# Issue cache operations. return self.gpu_cache
if blocks_to_swap_in.numel() > 0:
self.cache_engine.swap_in(blocks_to_swap_in)
if blocks_to_swap_out.numel() > 0:
self.cache_engine.swap_out(blocks_to_swap_out)
if blocks_to_copy.numel() > 0:
self.cache_engine.copy(blocks_to_copy)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def prepare_worker_input(
self, self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
execute_model_req: Optional[ExecuteModelRequest] = None num_seq_groups = len(execute_model_req.seq_group_metadata_list)
) -> List[Union[SamplerOutput, PoolerOutput]]:
if not self.is_driver_worker:
self._execute_model_non_driver()
return []
if execute_model_req is None:
# This signals that there's no more requests to process for now.
# All workers are running infinite loop with broadcast_tensor_dict,
# and it stops the loop when the driver broadcasts an empty input.
# Send an empty input to notify all other workers to stop their
# execution loop.
broadcast_tensor_dict({}, src=0)
return []
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
num_seq_groups = len(seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync. # they contain parameters to launch cudamemcpyasync.
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
...@@ -273,59 +251,26 @@ class Worker(WorkerBase): ...@@ -273,59 +251,26 @@ class Worker(WorkerBase):
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device, device=self.device,
dtype=torch.int64).view(-1, 2) dtype=torch.int64).view(-1, 2)
data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return []
output = self.model_runner.execute_model(seq_group_metadata_list, return WorkerInput(
self.gpu_cache) num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
# Worker only supports single-step execution. Wrap the output in a list blocks_to_swap_out=blocks_to_swap_out,
# to conform to interface. blocks_to_copy=blocks_to_copy,
return [output] )
@torch.inference_mode() @torch.inference_mode()
def start_worker_execution_loop(self) -> None: def execute_worker(self, worker_input: WorkerInput) -> None:
"""Execute model loop in parallel worker. # Issue cache operations.
if (worker_input.blocks_to_swap_in is not None
You can stop the loop by executing a driver worker with an empty output. and worker_input.blocks_to_swap_in.numel() > 0):
See `stop_remote_worker_execution_loop` for more details. self.cache_engine.swap_in(worker_input.blocks_to_swap_in)
""" if (worker_input.blocks_to_swap_out is not None
while self._execute_model_non_driver(): and worker_input.blocks_to_swap_out.numel() > 0):
pass self.cache_engine.swap_out(worker_input.blocks_to_swap_out)
if (worker_input.blocks_to_copy is not None
def _execute_model_non_driver(self) -> bool: and worker_input.blocks_to_copy.numel() > 0):
"""Execute model in parallel worker. self.cache_engine.copy(worker_input.blocks_to_copy)
Returns True iff there are remaining sequences to process.
"""
assert not self.is_driver_worker
data = broadcast_tensor_dict(src=0)
if not data:
return False
num_seq_groups = data.get("num_seq_groups", 0)
blocks_to_swap_in = data.get("blocks_to_swap_in")
blocks_to_swap_out = data.get("blocks_to_swap_out")
blocks_to_copy = data.get("blocks_to_copy")
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return False
self.model_runner.execute_model(None, self.gpu_cache)
return True
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request) return self.model_runner.add_lora(lora_request)
......
import dataclasses
import importlib import importlib
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Set, Tuple from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (enable_trace_function_call_for_thread, is_hip, from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
update_environment_variables) update_environment_variables)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
logger = init_logger(__name__) logger = init_logger(__name__)
class WorkerBase(ABC): class WorkerBase(ABC):
"""Worker interface that allows vLLM to cleanly separate implementations for """Worker interface that allows vLLM to cleanly separate implementations for
different hardware. different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers.
""" """
@abstractmethod @abstractmethod
...@@ -46,13 +52,23 @@ class WorkerBase(ABC): ...@@ -46,13 +52,23 @@ class WorkerBase(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@torch.inference_mode()
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
while True:
output = self.execute_model(execute_model_req=None)
if output is None:
return None
@abstractmethod @abstractmethod
def execute_model( def execute_model(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]: ) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
...@@ -98,6 +114,150 @@ class LoraNotSupportedWorkerBase(WorkerBase): ...@@ -98,6 +114,150 @@ class LoraNotSupportedWorkerBase(WorkerBase):
raise ValueError(f"{type(self)} does not support LoRA") raise ValueError(f"{type(self)} does not support LoRA")
@dataclasses.dataclass(frozen=True)
class WorkerInput:
"""Local inputs to each worker. May contain device-specific data. These
fields should be broadcastable to other workers.
"""
num_seq_groups: Optional[int] = None
blocks_to_swap_in: Optional[torch.Tensor] = None
blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"],
tensor_dict: Dict[str, Any],
) -> "WorkerInput":
"""
Pop fields from the given tensor_dict and populate a new instance of
WorkerInput.
"""
return cls(
num_seq_groups=tensor_dict.pop("num_seq_groups"),
blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
)
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
"""
Extract broadcastable fields.
"""
tensor_dict = {
"num_seq_groups": self.num_seq_groups,
"blocks_to_swap_in": self.blocks_to_swap_in,
"blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy,
}
return tensor_dict
class LocalOrDistributedWorkerBase(WorkerBase):
"""
Partial implementation of WorkerBase that has a default `execute_model`
definition to perform metadata transfer between workers when in distributed
mode. Subclasses of this interface should use model runners that inherit
from ModelRunnerBase, and should only need to implement worker-local logic.
If custom control plane logic is needed to transfer metadata, or if the
model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
"""
is_driver_worker: bool
model_runner: ModelRunnerBase
@property
@abstractmethod
def do_metadata_broadcast(self) -> bool:
"""
Used by the default `execute_model` to check whether broadcast is
needed to transfer request inputs from the driver worker to other
workers in the TP group. If WorkerBase subclass only supports
single-worker execution, then this method should return False.
"""
raise NotImplementedError
@property
@abstractmethod
def kv_cache(self) -> Optional[List[torch.Tensor]]:
"""
Get the kv cache to pass to the worker's model runner. Used by the
default `execute_model`. If the worker's model runner does not follow
the ModelRunnerBase interface, then inherit from WorkerBase instead.
"""
raise NotImplementedError
@abstractmethod
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
"""
Prepare the inputs to WorkerBase.execute_worker from an execution
request. This method may move data to the worker's local device. It is
not allowed to communicate with other workers or devices.
"""
raise NotImplementedError
@abstractmethod
def execute_worker(self, worker_input: WorkerInput) -> None:
"""
Process an execution request.
"""
raise NotImplementedError
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
# This signals that there's no more requests to process for
# now. All workers are running infinite loop with
# broadcast_tensor_dict, and it stops the loop when the
# driver broadcasts an empty input. Send an empty input to
# notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0)
return None
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list))
if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(
model_input.as_broadcastable_tensor_dict())
broadcast_tensor_dict(broadcast_data, src=0)
else:
assert self.do_metadata_broadcast
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(
broadcast_data)
model_input = (
self.model_runner.
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
output = self.model_runner.execute_model(model_input, self.kv_cache)
# Worker only supports single-step execution. Wrap the output in a
# list to conform to interface.
return [output]
class WorkerWrapperBase: class WorkerWrapperBase:
""" """
The whole point of this class is to lazily initialize the worker. The whole point of this class is to lazily initialize the worker.
......
from typing import List, Optional, Tuple from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -14,6 +15,15 @@ from vllm.sampling_params import SamplingParams ...@@ -14,6 +15,15 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -24,7 +34,42 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ ...@@ -24,7 +34,42 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
] ]
class XPUModelRunner: @dataclass(frozen=True)
class ModelInputForXPU(ModelRunnerInputBase):
"""
Used by the NeuronModelRunner.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_input: Optional[Dict[str, torch.Tensor]] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["ModelInputForXPU"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForXPU":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
def __init__( def __init__(
self, self,
...@@ -130,15 +175,22 @@ class XPUModelRunner: ...@@ -130,15 +175,22 @@ class XPUModelRunner:
# Run the model with the dummy inputs. # Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config) num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers kv_caches = [None] * num_layers
self.execute_model(seqs, kv_caches) model_input = self.prepare_model_input(seqs)
self.execute_model(model_input, kv_caches)
torch.xpu.synchronize() torch.xpu.synchronize()
return return
def prepare_input_tensors( def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForXPU:
return (ModelInputForXPU.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, ) -> ModelInputForXPU:
Optional[torch.Tensor]]:
multi_modal_input = None multi_modal_input = None
if self.is_driver_worker: if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
...@@ -185,8 +237,11 @@ class XPUModelRunner: ...@@ -185,8 +237,11 @@ class XPUModelRunner:
num_prompts=0, num_prompts=0,
) )
return (input_tokens, input_positions, attn_metadata, return ModelInputForXPU(input_tokens=input_tokens,
sampling_metadata, multi_modal_input) input_positions=input_positions,
attn_metadata=attn_metadata,
sampling_metadata=sampling_metadata,
multi_modal_input=multi_modal_input)
def _prepare_decode( def _prepare_decode(
self, self,
...@@ -277,27 +332,25 @@ class XPUModelRunner: ...@@ -277,27 +332,25 @@ class XPUModelRunner:
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], model_input: ModelInputForXPU,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata,
multi_modal_input
) = self.prepare_input_tensors(seq_group_metadata_list)
model_executable = self.model model_executable = self.model
execute_model_kwargs = { execute_model_kwargs = {
"input_ids": input_tokens, "input_ids": model_input.input_tokens,
"positions": input_positions, "positions": model_input.input_positions,
"kv_caches": kv_caches, "kv_caches": kv_caches,
"attn_metadata": attn_metadata, "attn_metadata": model_input.attn_metadata,
} }
if self.vision_language_config: if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input}) execute_model_kwargs.update(
{"image_input": model_input.multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs) hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits. # Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata) logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker. # Only perform sampling in the driver worker.
if not self.is_driver_worker: if not self.is_driver_worker:
...@@ -306,7 +359,7 @@ class XPUModelRunner: ...@@ -306,7 +359,7 @@ class XPUModelRunner:
# Sample the next token. # Sample the next token.
output = self.model.sample( output = self.model.sample(
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=model_input.sampling_metadata,
) )
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment