# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional, Set

import torch
from neuronx_distributed_inference.modules.generation.sampling import (
    prepare_sampling_params)
from neuronx_distributed_inference.modules.lora_serving import (
    LoraCheckpoint, LoraServingConfig)

from vllm.config import VllmConfig
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuronx_distributed import (
    _get_model_architecture, get_neuron_model)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.worker.neuron_model_runner import (ModelInputForNeuron,
                                             NeuronModelRunner)

logger = init_logger(__name__)


class NeuronxDistributedModelRunner(NeuronModelRunner):
    """Neuron model runner backed by neuronx-distributed-inference (NxDI).

    Extends ``NeuronModelRunner`` with static multi-LoRA configuration taken
    from ``override_neuron_config`` and a dedicated execution path for the
    ``MllamaForConditionalGeneration`` architecture.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
    ):
        super().__init__(vllm_config)
        # All three are populated by load_model(); the LoRA ones stay None
        # when no LoRA config is present.
        self.lora_checkpoint = None
        self.model = None
        self.lora_serving_config = None

    @staticmethod
    def _get_lora_paths_strings(lora_modules: List[LoRAModulePath]):
        """Map each LoRA module's ``name`` to its ``path``.

        Returns ``None`` when ``lora_modules`` is empty or ``None``.
        """
        if not lora_modules:
            return None
        # NOTE(review): entries are accessed with ``.get``, so they are
        # presumably dict-like despite the annotation -- confirm with caller.
        return {
            module.get("name"): module.get("path")
            for module in lora_modules
        }

    def _get_nxdi_lora_config(self):
        """Build the NxDI ``LoraServingConfig`` from ``override_neuron_config``.

        The ``lora_modules`` and ``target_modules`` keys are popped from
        ``override_neuron_config`` (i.e. removed from that dict).

        Raises:
            ValueError: if more LoRA modules are configured than
                ``lora_config.max_loras`` allows.
        """
        override_neuron_config = self.model_config.override_neuron_config
        lora_modules = override_neuron_config.pop("lora_modules", None)
        target_modules = override_neuron_config.pop("target_modules", None)
        lora_ckpt_paths = self._get_lora_paths_strings(lora_modules)

        # lora_ckpt_paths is None when no modules were supplied; treat that
        # as zero adapters instead of failing on len(None).
        num_loras = len(lora_ckpt_paths) if lora_ckpt_paths else 0
        max_loras = self.lora_config.max_loras
        if max_loras < num_loras:
            # Exceptions do not apply %-formatting to extra arguments, so the
            # message has to be fully formatted here.
            raise ValueError(
                f"Number of LoRAs ({num_loras}) exceeds maximum "
                f"allowed ({max_loras})")

        return LoraServingConfig(
            max_loras=max_loras,
            max_lora_rank=self.lora_config.max_lora_rank,
            target_modules=target_modules,
            lora_ckpt_paths=lora_ckpt_paths,
        )

    def load_model(self) -> None:
        """Load the Neuron model, setting up LoRA serving first if enabled."""
        # Update LoRA config
        if self.lora_config is not None:
            self.lora_serving_config = self._get_nxdi_lora_config()
            self.lora_checkpoint = LoraCheckpoint(self.lora_serving_config)
        self.model = get_neuron_model(
            self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            lora_serving_config=self.lora_serving_config)

    def get_nxd_sampling_params(self, sampling_metadata):
        """Build NxDI sampling parameters for the current batch.

        Every one of the ``max_num_seqs`` slots starts with ``top_k=1``,
        ``top_p=1.0`` and ``temperature=1.0``; slots that correspond to an
        entry of ``sampling_metadata.seq_groups`` are overwritten with that
        group's values. A non-positive ``top_k`` is replaced by the maximum
        top-k (the on-device ``global_topk`` when on-device sampling is
        configured, otherwise the vocabulary size).
        """
        neuron_config = self.model.config.neuron_config
        if neuron_config.on_device_sampling_config:
            max_topk = neuron_config.on_device_sampling_config.global_topk
        else:
            max_topk = self.model.config.vocab_size

        max_num_seqs = self.scheduler_config.max_num_seqs
        top_k = [1] * max_num_seqs
        top_p = [1.0] * max_num_seqs
        temperature = [1.0] * max_num_seqs

        for index, seq_group_to_sample in enumerate(
                sampling_metadata.seq_groups):
            params = seq_group_to_sample.sampling_params
            top_k[index] = params.top_k if params.top_k > 0 else max_topk
            top_p[index] = params.top_p
            temperature[index] = params.temperature

        return prepare_sampling_params(batch_size=max_num_seqs,
                                       top_k=top_k,
                                       top_p=top_p,
                                       temperature=temperature)

    def get_multi_modal_data_neuron(self, input_images):
        """Placeholder for image preprocessing; always raises.

        Raises:
            NotImplementedError: multi-modal support is not implemented here.
        """
        raise NotImplementedError("need to restore multi-modal support")

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForNeuron,
        kv_caches: Optional[List[torch.Tensor]] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        """Run one forward pass and sample.

        Architectures other than ``MllamaForConditionalGeneration`` are
        delegated to the parent implementation. For Mllama, image inputs (or
        placeholder tensors when there are none) are passed to the model
        together with the NxDI sampling parameters.

        Raises:
            ValueError: if ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError(
                "NeuronModelRunner does not support multi-step execution.")

        if _get_model_architecture(
                self.model.config) != "MllamaForConditionalGeneration":
            return super().execute_model(model_input, kv_caches,
                                         intermediate_tensors, num_steps)

        sampling_params = self.get_nxd_sampling_params(
            model_input.sampling_metadata)

        # multi_modal_kwargs is None on the decode path (see
        # prepare_model_input), so guard before looking up the images.
        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        image_inputs = multi_modal_kwargs.get('image')

        if image_inputs is not None:
            pixel_values = []
            aspect_ratios = []
            num_chunks = []
            has_image = []
            for multi_modal_input in image_inputs:
                image_tensors = self.get_multi_modal_data_neuron(
                    multi_modal_input.squeeze(0))
                pixel_values.append(image_tensors[0])
                aspect_ratios.append(image_tensors[1])
                num_chunks.append(image_tensors[2])
                has_image.append(image_tensors[3])

            pixel_values = torch.cat(pixel_values, dim=0)
            aspect_ratios = torch.cat(aspect_ratios, dim=0)
            num_chunks = torch.cat(num_chunks, dim=0)
            has_image = torch.cat(has_image, dim=0)
        else:
            # No image: feed placeholder tensors and flag has_image as 0.
            pixel_values = torch.zeros([1, 1, 4, 3, 560, 560],
                                       dtype=torch.bfloat16)
            aspect_ratios = torch.ones([1, 1, 2], dtype=torch.int64)
            num_chunks = torch.tensor([[1]
                                       ])  # dummy num_chunks, will not be used
            has_image = torch.tensor([0])

        hidden_states = self.model(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            seq_ids=model_input.input_block_ids,
            pixel_values=pixel_values,
            aspect_ratios=aspect_ratios,
            sampling_params=sampling_params,
            num_chunks=num_chunks,
            has_image=has_image,
        )

        output = self.model.sample(
            hidden_states=hidden_states,
            sampling_metadata=model_input.sampling_metadata,
        )

        return [output]

    def _get_lora_adapter_ids(self, seq_group_metadata_list):
        """Return the per-sequence-group LoRA adapter ids for this batch.

        Without a LoRA checkpoint this is an all-zero int32 tensor of length
        ``len(seq_group_metadata_list)``. Otherwise adapter names are
        collected and converted by the checkpoint's
        ``convert_adapter_ids_to_indices``.
        """
        # set LoRA adapter IDs for multi-lora serving
        batch_size = len(seq_group_metadata_list)
        if self.lora_checkpoint is None:
            return torch.zeros((batch_size), dtype=torch.int32)

        # "0" indicates NxDI to use the base model for inference
        adapter_ids = ["0"] * batch_size
        for idx, seq_group_metadata in enumerate(seq_group_metadata_list):
            if seq_group_metadata.lora_request is not None:
                adapter_ids[idx] = seq_group_metadata.lora_request.lora_name

        # convert adapter_ids from strings to integers
        return self.lora_checkpoint.convert_adapter_ids_to_indices(
            adapter_ids, batch_size)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForNeuron:
        """Assemble the ``ModelInputForNeuron`` for one scheduler step.

        Prepares prompt or decode tensors, converts each request's sampling
        parameters in place when on-device sampling is enabled, resolves LoRA
        adapter ids and builds the sampling metadata.
        """
        multi_modal_kwargs = None
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids, seq_lens,
             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            seq_lens = None

        if not self._on_device_sampling_disabled:
            # Overwrites top_k/top_p/temperature on each request's
            # sampling_params object with the Neuron-converted values.
            for seq_group_metadata in seq_group_metadata_list:
                sampling_params = seq_group_metadata.sampling_params
                top_k, top_p, temperature = (
                    self._convert_to_neuron_sampling_params(sampling_params))
                sampling_params.top_k = top_k
                sampling_params.top_p = top_p
                sampling_params.temperature = temperature

        lora_adapter_ids = self._get_lora_adapter_ids(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            # query_lens is not needed if chunked prefill is not
            # supported. Since neuron worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens,
            self.device,
            self.pin_memory,
            generators=self.get_generators(finished_requests_ids))

        if current_platform.use_transformers_neuronx(
        ) and not self._on_device_sampling_disabled:
            # Once the request IDs are changed in current iteration, we will
            # update the on-device sampling parameters.
            current_batch_request_ids = [
                seq_group_meta_data.request_id
                for seq_group_meta_data in seq_group_metadata_list
            ]
            if current_batch_request_ids != self._previous_batch_request_ids:
                self._update_neuron_sampling_params(seq_group_metadata_list)
                self._previous_batch_request_ids = current_batch_request_ids

        return ModelInputForNeuron(input_tokens=input_tokens,
                                   input_positions=input_positions,
                                   input_block_ids=input_block_ids,
                                   sampling_metadata=sampling_metadata,
                                   multi_modal_kwargs=multi_modal_kwargs,
                                   adapter_ids=lora_adapter_ids)

    def remove_all_loras(self):
        """Unsupported: LoRAs are configured only via ``lora_modules``."""
        raise NotImplementedError(
            "Managing LoRAs is only supported through the "
            "lora_modules parameter in override_neuron_config")

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Unsupported: LoRAs are configured only via ``lora_modules``."""
        raise NotImplementedError(
            "Managing LoRAs is only supported through the "
            "lora_modules parameter in override_neuron_config")

    def add_lora(self, lora_request: LoRARequest):
        """Log a warning and ignore the request; nothing is added."""
        # The request is passed as a lazy %-style logging argument, so the
        # message needs a matching placeholder.
        logger.warning(
            "Adding LoRAs is only supported through the "
            "lora_modules parameter in override_neuron_config. If you supplied "
            "the parameter, you can ignore this warning. Ignoring "
            "lora request: %s", lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Unsupported: LoRAs are configured only via ``lora_modules``."""
        raise NotImplementedError(
            "Managing LoRAs is only supported through the "
            "lora_modules parameter in override_neuron_config")

    def pin_lora(self, lora_id: int) -> bool:
        """Unsupported: LoRAs are configured only via ``lora_modules``."""
        raise NotImplementedError(
            "Managing LoRAs is only supported through the "
            "lora_modules parameter in override_neuron_config")

    def list_loras(self) -> Set[int]:
        """Unsupported: LoRAs are configured only via ``lora_modules``."""
        raise NotImplementedError(
            "Managing LoRAs is only supported through the "
            "lora_modules parameter in override_neuron_config")