from typing import Dict, List, Optional, Set, Tuple

import torch

from vllm.attention import AttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import BatchType, ModelRunner

logger = init_logger(__name__)


class EmbeddingModelRunner(ModelRunner):
    """Model runner that feeds the model's hidden states to its pooler.

    The input preparation mirrors the driver/worker broadcast protocol of
    the parent runner, with one addition: a ``PoolingMetadata`` object is
    built and handed to ``self.model.pooler`` in ``execute_model``.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        """Forward every argument unchanged to ``ModelRunner.__init__``."""
        super().__init__(model_config,
                         parallel_config,
                         scheduler_config,
                         device_config,
                         cache_config,
                         load_config,
                         lora_config=lora_config,
                         kv_cache_dtype=kv_cache_dtype,
                         is_driver_worker=is_driver_worker,
                         vision_language_config=vision_language_config)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[torch.Tensor],
    ) -> Optional[PoolerOutput]:
        """Run one forward pass and pool the resulting hidden states.

        Args:
            seq_group_metadata_list: Sequence groups scheduled for this step.
            kv_caches: Accepted for interface compatibility. The value is
                discarded and replaced by a list of ``None`` (one entry per
                layer) before the model is called.

        Returns:
            Whatever ``self.model.pooler`` returns for the hidden states and
            the pooling metadata of this batch.
        """
        (input_tokens, input_positions, attn_metadata, pooling_metadata,
         lora_requests, lora_mapping, multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Currently cuda graph is only supported by the decode phase.
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        if prefill_meta is None and decode_meta.use_cuda_graph:
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model

        # The caller-supplied caches are intentionally overridden here.
        # NOTE(review): presumably because pooling models do not need a KV
        # cache -- confirm against the model implementations.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers

        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs["image_input"] = multi_modal_input

        hidden_states = model_executable(**execute_model_kwargs)

        return self.model.pooler(hidden_states=hidden_states,
                                 pooling_metadata=pooling_metadata)

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
               Set[LoRARequest], Optional[LoRAMapping], torch.Tensor]:
        """Build the model inputs on the driver, or receive them elsewhere.

        On the driver worker the prompt and decode inputs are prepared,
        coalesced into tensors and broadcast from rank 0. On every other
        worker the same data is received through ``broadcast_tensor_dict``
        and the pooling metadata is a placeholder whose fields are all
        ``None``.

        Args:
            seq_group_metadata_list: Sequence groups scheduled for this step.
                Only read on the driver worker.

        Returns:
            A tuple ``(input_tokens, input_positions, attn_metadata,
            pooling_metadata, lora_requests, lora_mapping,
            multi_modal_input)``. On the driver, ``lora_mapping`` is ``None``
            when no LoRA config is set.
        """
        if self.is_driver_worker:
            prefill_reqs = []
            decode_reqs = []
            for seq_group_meta in seq_group_metadata_list:
                if seq_group_meta.is_prompt:
                    prefill_reqs.append(seq_group_meta)
                else:
                    decode_reqs.append(seq_group_meta)

            # Prepare input tensors.
            (
                input_tokens,
                input_positions,
                prefill_attn_metadata,
                prompt_lens,
                subquery_lens,
                lora_index_mapping,
                lora_prompt_mapping,
                lora_requests,
                multi_modal_input,
                slot_mapping,
            ) = self._prepare_prompt(prefill_reqs)
            (
                decode_input_tokens,
                decode_input_positions,
                decode_attn_metadata,
                decode_lora_index_mapping,
                decode_lora_prompt_mapping,
                decode_lora_requests,
                decode_slot_mapping,
            ) = self._prepare_decode(decode_reqs)

            # Prepare PoolingMetadata
            pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                     prompt_lens)

            if not self.scheduler_config.chunked_prefill_enabled:
                # Without chunked prefill a batch must be all-prefill or
                # all-decode: at most one of the two lists may be non-empty.
                assert not (prefill_reqs and decode_reqs), (
                    "Prefill and decode requests cannot be mixed in one "
                    "batch unless chunked prefill is enabled.")

            num_prefills = len(prompt_lens)
            num_prefill_tokens = len(input_tokens)
            num_decode_tokens = len(decode_input_tokens)

            # Coalesce tensors. Note that attn_metadata is currently not
            # coalesced for simplicity.
            input_tokens.extend(decode_input_tokens)
            input_positions.extend(decode_input_positions)
            slot_mapping.extend(decode_slot_mapping)
            lora_index_mapping.extend(decode_lora_index_mapping)
            lora_prompt_mapping.extend(decode_lora_prompt_mapping)
            lora_requests.update(decode_lora_requests)

            input_tokens = torch.tensor(input_tokens,
                                        dtype=torch.long,
                                        device=self.device)
            input_positions = torch.tensor(input_positions,
                                           dtype=torch.long,
                                           device=self.device)
            slot_mapping = torch.tensor(slot_mapping,
                                        dtype=torch.long,
                                        device=self.device)

            if self.lora_config:
                lora_mapping = LoRAMapping(
                    lora_index_mapping,
                    lora_prompt_mapping,
                )
            else:
                lora_mapping = None

            # Broadcast the metadata.
            # If batch contains both prefill and decode, it sends 2 broadcasts.
            # If it only contains 1 type, it triggers a single broadcast.
            if (prefill_attn_metadata is not None
                    and decode_attn_metadata is not None):
                batch_type = BatchType.MIXED
            elif prefill_attn_metadata is not None:
                batch_type = BatchType.PREFILL
            else:
                batch_type = BatchType.DECODE

            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_input": multi_modal_input,
                "num_prefill_tokens": num_prefill_tokens,
                "num_decode_tokens": num_decode_tokens,
                "slot_mapping": slot_mapping,
                "num_prefills": num_prefills,
                "batch_type": batch_type,
            }
            if prefill_attn_metadata is not None:
                metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
            else:
                assert decode_attn_metadata is not None
                metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)

            # Broadcast decode attn metadata for mixed batch type.
            # The additional broadcast costs 300us overhead on 4 A10 GPUs.
            # We can potentially reduce the overhead by coalescing tensors.
            if batch_type == BatchType.MIXED:
                assert decode_attn_metadata is not None
                metadata_dict = decode_attn_metadata.asdict_zerocopy()
                broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            # Pop the keys added by the driver; what is left in the dict is
            # passed as keyword arguments to the attention backend below.
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            slot_mapping = metadata_dict.pop("slot_mapping")
            num_prefills = metadata_dict.pop("num_prefills")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_input = metadata_dict.pop("multi_modal_input")
            num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
            num_decode_tokens = metadata_dict.pop("num_decode_tokens")
            batch_type = metadata_dict.pop("batch_type")

            # Create an attention metadata.
            prefill_attn_metadata = None
            decode_attn_metadata = None
            if batch_type in (BatchType.PREFILL, BatchType.MIXED):
                prefill_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            else:
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)

            # Placeholder: the real pooling metadata is only built on the
            # driver worker.
            pooling_metadata = PoolingMetadata(seq_groups=None,
                                               seq_data=None,
                                               prompt_lens=None)

            # if it is a mixed batch, decode attn_metadata is broadcasted
            # separately.
            if batch_type == BatchType.MIXED:
                metadata_dict = broadcast_tensor_dict(src=0)
                decode_attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)

        attn_metadata = AttentionMetadata(
            num_prefills=num_prefills,
            slot_mapping=slot_mapping,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            prefill_metadata=prefill_attn_metadata,
            decode_metadata=decode_attn_metadata,
        )

        return (input_tokens, input_positions, attn_metadata, pooling_metadata,
                lora_requests, lora_mapping, multi_modal_input)

    def _prepare_pooling(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> PoolingMetadata:
        """Prepare PoolingMetadata for the sequence group metadata list.

        Args:
            seq_group_metadata_list: Sequence groups scheduled for this step.
            prompt_lens: Prompt lengths, stored on the result unchanged.

        Returns:
            A ``PoolingMetadata`` holding one ``(seq_ids, pooling_params)``
            pair per sequence group, the merged ``seq_data`` of all groups,
            and ``prompt_lens``.
        """
        seq_groups: List[Tuple[List[int], PoolingParams]] = []
        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            seq_groups.append((seq_ids, seq_group_metadata.pooling_params))
            seq_data.update(seq_group_metadata.seq_data)

        return PoolingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
        )