Unverified commit c5832d2a, authored by Murali Andoorveedu and committed by GitHub
Browse files

[Core] Pipeline Parallel Support (#4412)


Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
parent 15aba081
...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
def load_column_parallel_weight(param: torch.nn.Parameter, def load_column_parallel_weight(param: torch.nn.Parameter,
...@@ -412,6 +412,7 @@ class Phi3SmallForCausalLM(nn.Module): ...@@ -412,6 +412,7 @@ class Phi3SmallForCausalLM(nn.Module):
positions: Optional[torch.LongTensor], positions: Optional[torch.LongTensor],
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
output_hidden_states = self.model( output_hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
......
...@@ -35,7 +35,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel ...@@ -35,7 +35,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision from .interfaces import SupportsVision
...@@ -381,9 +381,13 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -381,9 +381,13 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return None return None
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, **kwargs: object): attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object):
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
...@@ -398,6 +402,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -398,6 +402,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
...@@ -27,7 +27,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -27,7 +27,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
...@@ -245,6 +245,7 @@ class QWenLMHeadModel(nn.Module): ...@@ -245,6 +245,7 @@ class QWenLMHeadModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -331,6 +331,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -331,6 +331,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -50,7 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -50,7 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class Qwen2MoeMLP(nn.Module): class Qwen2MoeMLP(nn.Module):
...@@ -397,6 +397,7 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -397,6 +397,7 @@ class Qwen2MoeForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -41,7 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -41,7 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class StablelmMLP(nn.Module): class StablelmMLP(nn.Module):
...@@ -250,6 +250,7 @@ class StablelmForCausalLM(nn.Module): ...@@ -250,6 +250,7 @@ class StablelmForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class Starcoder2Attention(nn.Module): class Starcoder2Attention(nn.Module):
...@@ -262,6 +262,7 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -262,6 +262,7 @@ class Starcoder2ForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -320,6 +320,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA): ...@@ -320,6 +320,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -770,6 +770,34 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput): ...@@ -770,6 +770,34 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
return self.embeddings == other.embeddings return self.embeddings == other.embeddings
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.

    The container is a thin wrapper around a ``name -> tensor`` mapping:

    * ``obj["name"]`` returns the tensor stored under that name.
    * ``obj[start:stop]`` returns a new ``IntermediateTensors`` in which
      every stored tensor has been sliced with the same slice.
    * ``len(obj)`` is the number of stored tensors.
    """

    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: Union[str, slice]):
        """Look up one tensor by name, or slice every tensor at once.

        Raises:
            KeyError: if ``key`` is a string that is not stored.
            TypeError: if ``key`` is neither a ``str`` nor a ``slice``.
        """
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        # Previously an unsupported key fell through and returned None, which
        # only surfaced later as an unrelated attribute error at the caller.
        raise TypeError(
            f"IntermediateTensors indices must be str or slice, "
            f"not {type(key).__name__}")

    def __setitem__(self, key: str, value):
        """Store ``value`` under ``key``, replacing any existing entry."""
        self.tensors[key] = value

    def __len__(self):
        """Return the number of stored tensors."""
        return len(self.tensors)

    def __eq__(self, other: object):
        """Return True iff ``other`` holds the same names with equal tensors.

        The previous implementation returned ``self`` for any instance of the
        same class, so it neither produced a bool nor compared contents.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        # torch.equal checks both shape and element values.
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"
@dataclass @dataclass
class SamplerOutput: class SamplerOutput:
"""For each sequence group, we generate a list of SequenceOutput object, """For each sequence group, we generate a list of SequenceOutput object,
...@@ -896,6 +924,8 @@ class ExecuteModelRequest: ...@@ -896,6 +924,8 @@ class ExecuteModelRequest:
blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
# Blocks to copy. Source to dest block. # Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
# Virtual engine ID for pipeline parallel.
virtual_engine: int = 0
# The number of slots for lookahead decoding. # The number of slots for lookahead decoding.
num_lookahead_slots: int = 0 num_lookahead_slots: int = 0
# The number of requests in the running queue. # The number of requests in the running queue.
...@@ -914,6 +944,7 @@ class ExecuteModelRequest: ...@@ -914,6 +944,7 @@ class ExecuteModelRequest:
blocks_to_swap_in=self.blocks_to_swap_in.copy(), blocks_to_swap_in=self.blocks_to_swap_in.copy(),
blocks_to_swap_out=self.blocks_to_swap_out.copy(), blocks_to_swap_out=self.blocks_to_swap_out.copy(),
blocks_to_copy=self.blocks_to_copy.copy(), blocks_to_copy=self.blocks_to_copy.copy(),
virtual_engine=self.virtual_engine,
num_lookahead_slots=self.num_lookahead_slots, num_lookahead_slots=self.num_lookahead_slots,
running_queue_size=self.running_queue_size, running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states, previous_hidden_states=self.previous_hidden_states,
......
...@@ -6,7 +6,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -6,7 +6,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner) ModelRunner)
...@@ -76,7 +77,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -76,7 +77,7 @@ class TP1DraftModelRunner(ModelRunner):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInputForGPUWithSamplingMetadata: virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata:
"""A temporary solution that caches the seq_group_metadata_list """A temporary solution that caches the seq_group_metadata_list
for multi-step execution. for multi-step execution.
TODO: In-place update model_input and remove this function. TODO: In-place update model_input and remove this function.
...@@ -115,6 +116,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -115,6 +116,7 @@ class TP1DraftModelRunner(ModelRunner):
self, self,
model_input: ModelInputForGPUWithSamplingMetadata, model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
# Since we do not broadcast data inside execute_model anymore, # Since we do not broadcast data inside execute_model anymore,
...@@ -130,6 +132,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -130,6 +132,7 @@ class TP1DraftModelRunner(ModelRunner):
self.set_active_loras(model_input.lora_requests, self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping) model_input.lora_mapping)
virtual_engine = model_input.virtual_engine
outputs: List[SamplerOutput] = [] outputs: List[SamplerOutput] = []
for step in range(num_steps): for step in range(num_steps):
# Currently cuda graph is only supported by the decode phase. # Currently cuda graph is only supported by the decode phase.
...@@ -139,7 +142,8 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -139,7 +142,8 @@ class TP1DraftModelRunner(ModelRunner):
if prefill_meta is None and decode_meta.use_cuda_graph: if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0] graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size] model_executable = (
self.graph_runners[virtual_engine][graph_batch_size])
else: else:
model_executable = self.model model_executable = self.model
...@@ -149,6 +153,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -149,6 +153,7 @@ class TP1DraftModelRunner(ModelRunner):
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata, attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs, **multi_modal_kwargs,
) )
......
...@@ -38,7 +38,11 @@ class CacheEngine: ...@@ -38,7 +38,11 @@ class CacheEngine:
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks self.num_gpu_blocks = cache_config.num_gpu_blocks
if self.num_gpu_blocks:
self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
self.num_cpu_blocks = cache_config.num_cpu_blocks self.num_cpu_blocks = cache_config.num_cpu_blocks
if self.num_cpu_blocks:
self.num_cpu_blocks //= parallel_config.pipeline_parallel_size
if cache_config.cache_dtype == "auto": if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype self.dtype = model_config.dtype
......
...@@ -13,7 +13,8 @@ from vllm.logger import init_logger ...@@ -13,7 +13,8 @@ from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import make_tensor_with_pad from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerBase, ModelRunnerInputBase,
...@@ -315,6 +316,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -315,6 +316,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> CPUModelInput: ) -> CPUModelInput:
multi_modal_kwargs = None multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
...@@ -351,6 +353,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -351,6 +353,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self, self,
model_input: CPUModelInput, model_input: CPUModelInput,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
if num_steps > 1: if num_steps > 1:
......
...@@ -167,8 +167,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -167,8 +167,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
is_driver_worker=is_driver_worker) is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
self.cache_engine: CPUCacheEngine self.cache_engine: List[CPUCacheEngine]
self.cpu_cache: List[torch.Tensor] self.cpu_cache: List[List[torch.Tensor]]
def init_device(self) -> None: def init_device(self) -> None:
self.init_distributed_environment() self.init_distributed_environment()
...@@ -242,17 +242,24 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -242,17 +242,24 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"initializing the engine.") "initializing the engine.")
def _init_cache_engine(self) -> None: def _init_cache_engine(self) -> None:
self.cache_engine = CPUCacheEngine(self.cache_config, self.cache_engine = [
self.model_config, CPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.parallel_config, self.device_config)
self.device_config) for _ in range(self.parallel_config.pipeline_parallel_size)
self.cpu_cache = self.cache_engine.cpu_cache ]
self.model_runner.block_size = self.cache_engine.block_size self.cpu_cache = [
self.cache_engine[ve].cpu_cache
assert self.cpu_cache is not None for ve in range(self.parallel_config.pipeline_parallel_size)
]
self.model_runner.block_size = self.cache_engine[0].block_size
assert all(
self.cpu_cache[ve] is not None
for ve in range(self.parallel_config.pipeline_parallel_size))
# Populate the cache to warmup the memory # Populate the cache to warmup the memory
for layer_cache in self.cpu_cache: for ve in range(self.parallel_config.pipeline_parallel_size):
for layer_cache in self.cpu_cache[ve]:
layer_cache.fill_(0) layer_cache.fill_(0)
@property @property
...@@ -260,7 +267,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -260,7 +267,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
return self.parallel_config.tensor_parallel_size > 1 return self.parallel_config.tensor_parallel_size > 1
@property @property
def kv_cache(self) -> Optional[List[torch.Tensor]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.cpu_cache return self.cpu_cache
def execute_worker( def execute_worker(
...@@ -269,12 +276,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -269,12 +276,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
) -> None: ) -> None:
if (worker_input.blocks_to_copy is not None if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0): and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine.copy(worker_input.blocks_to_copy) self.cache_engine[worker_input.virtual_engine].copy(
worker_input.blocks_to_copy)
@torch.inference_mode() @torch.inference_mode()
def prepare_worker_input( def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput: self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
assert execute_model_req is not None assert execute_model_req is not None
virtual_engine = execute_model_req.virtual_engine
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
blocks_to_copy = execute_model_req.blocks_to_copy blocks_to_copy = execute_model_req.blocks_to_copy
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
...@@ -285,6 +294,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -285,6 +294,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
return WorkerInput( return WorkerInput(
num_seq_groups=num_seq_groups, num_seq_groups=num_seq_groups,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
) )
def init_distributed_environment(self) -> None: def init_distributed_environment(self) -> None:
......
...@@ -9,7 +9,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -9,7 +9,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -57,6 +58,7 @@ class EmbeddingModelRunner( ...@@ -57,6 +58,7 @@ class EmbeddingModelRunner(
self, self,
model_input: ModelInputForGPUWithPoolingMetadata, model_input: ModelInputForGPUWithPoolingMetadata,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[PoolerOutput]]: ) -> Optional[List[PoolerOutput]]:
if num_steps > 1: if num_steps > 1:
...@@ -73,10 +75,12 @@ class EmbeddingModelRunner( ...@@ -73,10 +75,12 @@ class EmbeddingModelRunner(
assert model_input.attn_metadata is not None assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata decode_meta = model_input.attn_metadata.decode_metadata
virtual_engine = model_input.virtual_engine
if prefill_meta is None and decode_meta.use_cuda_graph: if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0] graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size] model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
else: else:
model_executable = self.model model_executable = self.model
...@@ -115,6 +119,7 @@ class EmbeddingModelRunner( ...@@ -115,6 +119,7 @@ class EmbeddingModelRunner(
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
) -> ModelInputForGPUWithPoolingMetadata: ) -> ModelInputForGPUWithPoolingMetadata:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors( model_input = self._prepare_model_input_tensors(
......
...@@ -8,6 +8,7 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, ...@@ -8,6 +8,7 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
import numpy as np import numpy as np
import torch import torch
import torch.distributed
import torch.nn as nn import torch.nn as nn
try: try:
...@@ -25,6 +26,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend ...@@ -25,6 +26,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -37,7 +39,8 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig ...@@ -37,7 +39,8 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import supports_lora from vllm.model_executor.models.interfaces import supports_lora
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad) is_pin_memory_available, make_tensor_with_pad)
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
...@@ -81,6 +84,7 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -81,6 +84,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
lora_requests: Optional[Set[LoRARequest]] = None lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
virtual_engine: int = 0
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = { tensor_dict = {
...@@ -89,6 +93,7 @@ class ModelInputForGPU(ModelRunnerInputBase): ...@@ -89,6 +93,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict return tensor_dict
...@@ -122,6 +127,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): ...@@ -122,6 +127,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict, _add_sampling_metadata_broadcastable_dict(tensor_dict,
...@@ -179,7 +185,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -179,7 +185,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.sliding_window = model_config.get_sliding_window() self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.graph_memory_pool: Optional[Tuple[ self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture. int, int]] = None # Set during graph capture.
# When using CUDA graph, the input block tables must be padded to # When using CUDA graph, the input block tables must be padded to
...@@ -787,9 +796,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -787,9 +796,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_num_seqs = min( max_num_seqs = min(
max_num_seqs, max_num_seqs,
int(max_num_batched_tokens / vlm_config.image_feature_size)) int(max_num_batched_tokens / vlm_config.image_feature_size))
batch_size = 0
for group_id in range(max_num_seqs): for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs + seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs)) (group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \ seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len) .dummy_data_for_profiling(model_config, seq_len)
...@@ -811,7 +822,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -811,7 +822,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
num_layers = self.model_config.get_num_layers(self.parallel_config) num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers kv_caches = [None] * num_layers
model_input = self.prepare_model_input(seqs) model_input = self.prepare_model_input(seqs)
self.execute_model(model_input, kv_caches) intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize() torch.cuda.synchronize()
return return
...@@ -847,7 +864,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -847,7 +864,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
return self.lora_manager.list_loras() return self.lora_manager.list_loras()
@torch.inference_mode() @torch.inference_mode()
def capture_model(self, kv_caches: List[torch.Tensor]) -> None: def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
"""Cuda graph capture a model. """Cuda graph capture a model.
Note that CUDA graph's performance gain is negligible if number Note that CUDA graph's performance gain is negligible if number
...@@ -880,10 +897,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -880,10 +897,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
slot_mapping.fill_(_PAD_SLOT_ID) slot_mapping.fill_(_PAD_SLOT_ID)
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda()
intermediate_inputs = None
if not get_pp_group().is_first_rank:
intermediate_inputs = self.model.make_empty_intermediate_tensors(
batch_size=max_batch_size,
dtype=self.model_config.dtype,
device=self.device)
# Prepare buffer for outputs. These will be reused for all batch sizes. # Prepare buffer for outputs. These will be reused for all batch sizes.
# It will be filled after the first graph capture. # It will be filled after the first graph capture.
hidden_states: Optional[torch.Tensor] = None hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
None
] * self.parallel_config.pipeline_parallel_size
graph_batch_size = _get_graph_batch_size( graph_batch_size = _get_graph_batch_size(
self.scheduler_config.max_num_seqs) self.scheduler_config.max_num_seqs)
...@@ -912,13 +937,17 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -912,13 +937,17 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
with graph_capture() as graph_capture_context: with graph_capture() as graph_capture_context:
# NOTE: Capturing the largest batch size first may help reduce the # NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph. # memory usage of CUDA graph.
for virtual_engine in range(
self.parallel_config.pipeline_parallel_size):
for batch_size in reversed(batch_size_capture_list): for batch_size in reversed(batch_size_capture_list):
if self.attn_backend.get_name() == "flashinfer": if self.attn_backend.get_name() == "flashinfer":
indptr_buffer = indptr_buffer[:batch_size + 1] indptr_buffer = indptr_buffer[:batch_size + 1]
last_page_len_buffer = last_page_len_buffer[:batch_size] last_page_len_buffer = last_page_len_buffer[:
batch_size]
num_qo_heads = self.model_config.get_num_attention_heads( num_qo_heads = (
self.parallel_config) self.model_config.get_num_attention_heads(
self.parallel_config))
num_kv_heads = self.model_config.get_num_kv_heads( num_kv_heads = self.model_config.get_num_kv_heads(
self.parallel_config) self.parallel_config)
if num_qo_heads // num_kv_heads >= 4: if num_qo_heads // num_kv_heads >= 4:
...@@ -927,8 +956,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -927,8 +956,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
use_tensor_cores = False use_tensor_cores = False
decode_wrapper = \ decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper( CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
decode_workspace_buffer, indptr_buffer, indices_buffer, decode_workspace_buffer, indptr_buffer,
last_page_len_buffer, "NHD", use_tensor_cores) indices_buffer, last_page_len_buffer, "NHD",
use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype( kv_cache_dtype = get_kv_cache_torch_dtype(
self.kv_cache_dtype, self.model_config.dtype) self.kv_cache_dtype, self.model_config.dtype)
...@@ -990,8 +1020,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -990,8 +1020,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
) )
self.set_active_loras(set(), lora_mapping) self.set_active_loras(set(), lora_mapping)
graph_runner = CUDAGraphRunner(self.model, graph_runner = CUDAGraphRunner(
self.attn_backend.get_name()) self.model, self.attn_backend.get_name())
if self.attn_backend.get_name() == "flashinfer": if self.attn_backend.get_name() == "flashinfer":
graph_runner.flashinfer_indptr_buffer = indptr_buffer graph_runner.flashinfer_indptr_buffer = indptr_buffer
...@@ -1006,15 +1036,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1006,15 +1036,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
graph_runner.capture( graph_runner.capture(
input_tokens[:batch_size], input_tokens[:batch_size],
input_positions[:batch_size], input_positions[:batch_size],
hidden_states[:batch_size] hidden_or_intermediate_states[
if hidden_states is not None else None, virtual_engine] # type: ignore
kv_caches, [:batch_size]
if hidden_or_intermediate_states[virtual_engine]
is not None else None,
intermediate_inputs[:batch_size]
if intermediate_inputs is not None else None,
kv_caches[virtual_engine],
attn_metadata, attn_metadata,
memory_pool=self.graph_memory_pool, memory_pool=self.graph_memory_pool,
stream=graph_capture_context.stream, stream=graph_capture_context.stream,
) )
self.graph_memory_pool = graph_runner.graph.pool() self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner self.graph_runners[virtual_engine][batch_size] = (
graph_runner)
end_time = time.perf_counter() end_time = time.perf_counter()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
...@@ -1047,6 +1083,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1047,6 +1083,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> ModelInputForGPUWithSamplingMetadata: ) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including """Prepare the model input based on a given sequence group, including
metadata for the sampling step. metadata for the sampling step.
...@@ -1072,15 +1109,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1072,15 +1109,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if seq_group_metadata_list else None) if seq_group_metadata_list else None)
return dataclasses.replace(model_input, return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
is_prompt=is_prompt) is_prompt=is_prompt,
virtual_engine=virtual_engine)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
model_input: ModelInputForGPUWithSamplingMetadata, model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1: if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner") raise ValueError("num_steps > 1 is not supported in ModelRunner")
...@@ -1124,27 +1163,34 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1124,27 +1163,34 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
assert model_input.attn_metadata is not None assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata decode_meta = model_input.attn_metadata.decode_metadata
# TODO(andoorve): We can remove this once all
# virtual engines share the same kv cache.
virtual_engine = model_input.virtual_engine
if prefill_meta is None and decode_meta.use_cuda_graph: if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0] graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size] model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
else: else:
model_executable = self.model model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {} multi_modal_kwargs = model_input.multi_modal_kwargs or {}
hidden_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata, attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs, **multi_modal_kwargs,
) )
# Compute the logits. # Compute the logits in the last pipeline stage.
logits = self.model.compute_logits(hidden_states, if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata) model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker: if not self.is_driver_worker:
return [] return []
...@@ -1159,9 +1205,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1159,9 +1205,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
assert model_input.sampling_metadata is not None assert model_input.sampling_metadata is not None
indices = model_input.sampling_metadata.selected_token_indices indices = model_input.sampling_metadata.selected_token_indices
if model_input.is_prompt: if model_input.is_prompt:
hidden_states = hidden_states.index_select(0, indices) hidden_states = hidden_or_intermediate_states.index_select(
0, indices)
elif decode_meta.use_cuda_graph: elif decode_meta.use_cuda_graph:
hidden_states = hidden_states[:len(indices)] hidden_states = hidden_or_intermediate_states[:len(indices)]
else:
hidden_states = hidden_or_intermediate_states
output.hidden_states = hidden_states output.hidden_states = hidden_states
...@@ -1195,13 +1244,15 @@ class CUDAGraphRunner: ...@@ -1195,13 +1244,15 @@ class CUDAGraphRunner:
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: Optional[torch.Tensor], hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
torch.Tensor]],
intermediate_inputs: Optional[IntermediateTensors],
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
memory_pool: Optional[Tuple[int, int]], memory_pool: Optional[Tuple[int, int]],
stream: torch.cuda.Stream, stream: torch.cuda.Stream,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
assert self._graph is None assert self._graph is None
# Run the model a few times without capturing the graph. # Run the model a few times without capturing the graph.
# This is to make sure that the captured graph does not include the # This is to make sure that the captured graph does not include the
...@@ -1213,6 +1264,7 @@ class CUDAGraphRunner: ...@@ -1213,6 +1264,7 @@ class CUDAGraphRunner:
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
intermediate_inputs,
**kwargs, **kwargs,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -1220,18 +1272,27 @@ class CUDAGraphRunner: ...@@ -1220,18 +1272,27 @@ class CUDAGraphRunner:
# Capture the graph. # Capture the graph.
self._graph = torch.cuda.CUDAGraph() self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
output_hidden_states = self.model( output_hidden_or_intermediate_states = self.model(
input_ids, input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
intermediate_inputs,
**kwargs, **kwargs,
) )
if hidden_states is not None: if hidden_or_intermediate_states is not None:
hidden_states.copy_(output_hidden_states) if get_pp_group().is_last_rank:
hidden_or_intermediate_states.copy_(
output_hidden_or_intermediate_states)
else:
for key in hidden_or_intermediate_states.tensors:
hidden_or_intermediate_states[key].copy_(
output_hidden_or_intermediate_states[key])
else: else:
hidden_states = output_hidden_states hidden_or_intermediate_states = (
del output_hidden_states output_hidden_or_intermediate_states)
del output_hidden_or_intermediate_states
# make sure `output_hidden_states` is deleted # make sure `output_hidden_states` is deleted
# in the graph's memory pool # in the graph's memory pool
gc.collect() gc.collect()
...@@ -1255,8 +1316,15 @@ class CUDAGraphRunner: ...@@ -1255,8 +1316,15 @@ class CUDAGraphRunner:
attn_metadata.decode_metadata.seq_lens_tensor, attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables, "block_tables": attn_metadata.decode_metadata.block_tables,
} }
self.output_buffers = {"hidden_states": hidden_states} if intermediate_inputs is not None:
return hidden_states self.input_buffers.update(intermediate_inputs.tensors)
if get_pp_group().is_last_rank:
self.output_buffers = {
"hidden_states": hidden_or_intermediate_states
}
else:
self.output_buffers = hidden_or_intermediate_states
return hidden_or_intermediate_states
def forward( def forward(
self, self,
...@@ -1264,6 +1332,7 @@ class CUDAGraphRunner: ...@@ -1264,6 +1332,7 @@ class CUDAGraphRunner:
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
# KV caches are fixed tensors, so we don't need to copy them. # KV caches are fixed tensors, so we don't need to copy them.
...@@ -1280,12 +1349,19 @@ class CUDAGraphRunner: ...@@ -1280,12 +1349,19 @@ class CUDAGraphRunner:
non_blocking=True) non_blocking=True)
self.input_buffers["block_tables"].copy_( self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True) attn_metadata.decode_metadata.block_tables, non_blocking=True)
if intermediate_tensors is not None:
for key in intermediate_tensors.tensors:
self.input_buffers[key].copy_(intermediate_tensors[key],
non_blocking=True)
# Run the graph. # Run the graph.
self.graph.replay() self.graph.replay()
# Return the output tensor. # Return the output tensor.
if get_pp_group().is_last_rank:
return self.output_buffers["hidden_states"] return self.output_buffers["hidden_states"]
return self.output_buffers
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs) return self.forward(*args, **kwargs)
......
...@@ -5,7 +5,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, ...@@ -5,7 +5,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import torch import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
...@@ -137,6 +138,7 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -137,6 +138,7 @@ class ModelRunnerBase(ABC, Generic[T]):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> T: ) -> T:
""" """
Prepare the inputs to ModelRunnerBase.execute_model from an execution Prepare the inputs to ModelRunnerBase.execute_model from an execution
...@@ -150,6 +152,7 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -150,6 +152,7 @@ class ModelRunnerBase(ABC, Generic[T]):
self, self,
model_input: T, model_input: T,
kv_caches: Optional[List[torch.Tensor]], kv_caches: Optional[List[torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors],
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
""" """
......
...@@ -9,7 +9,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, ...@@ -9,7 +9,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
...@@ -175,6 +176,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -175,6 +176,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> ModelInputForNeuron: ) -> ModelInputForNeuron:
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
# all decodes. # all decodes.
...@@ -207,6 +209,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -207,6 +209,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self, self,
model_input: ModelInputForNeuron, model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None, kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
if num_steps > 1: if num_steps > 1:
......
...@@ -80,7 +80,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -80,7 +80,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
return False return False
@property @property
def kv_cache(self) -> Optional[List[torch.Tensor]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None return None
@torch.inference_mode() @torch.inference_mode()
......
...@@ -59,9 +59,9 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -59,9 +59,9 @@ class Worker(LocalOrDistributedWorkerBase):
self.lora_config = lora_config self.lora_config = lora_config
self.load_config = load_config self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if parallel_config and is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert rank % parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group."
if self.model_config.trust_remote_code: if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
...@@ -99,9 +99,9 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -99,9 +99,9 @@ class Worker(LocalOrDistributedWorkerBase):
) )
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
self.cache_engine: CacheEngine self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as embedding models don't initialize kv_caches # Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[torch.tensor]] = None self.gpu_cache: Optional[List[List[torch.tensor]]] = None
def init_device(self) -> None: def init_device(self) -> None:
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
...@@ -217,10 +217,15 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -217,10 +217,15 @@ class Worker(LocalOrDistributedWorkerBase):
def _init_cache_engine(self): def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.cache_engine = [
self.parallel_config, CacheEngine(self.cache_config, self.model_config,
self.device_config) self.parallel_config, self.device_config)
self.gpu_cache = self.cache_engine.gpu_cache for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.gpu_cache = [
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
def _warm_up_model(self) -> None: def _warm_up_model(self) -> None:
if not self.model_config.enforce_eager: if not self.model_config.enforce_eager:
...@@ -234,12 +239,13 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -234,12 +239,13 @@ class Worker(LocalOrDistributedWorkerBase):
return self.parallel_config.tensor_parallel_size > 1 return self.parallel_config.tensor_parallel_size > 1
@property @property
def kv_cache(self) -> Optional[List[torch.Tensor]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.gpu_cache return self.gpu_cache
@torch.inference_mode() @torch.inference_mode()
def prepare_worker_input( def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput: self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_seq_groups = len(execute_model_req.seq_group_metadata_list) num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync. # they contain parameters to launch cudamemcpyasync.
...@@ -261,20 +267,24 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -261,20 +267,24 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
) )
@torch.inference_mode() @torch.inference_mode()
def execute_worker(self, worker_input: WorkerInput) -> None: def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
# Issue cache operations. # Issue cache operations.
if (worker_input.blocks_to_swap_in is not None if (worker_input.blocks_to_swap_in is not None
and worker_input.blocks_to_swap_in.numel() > 0): and worker_input.blocks_to_swap_in.numel() > 0):
self.cache_engine.swap_in(worker_input.blocks_to_swap_in) self.cache_engine[virtual_engine].swap_in(
worker_input.blocks_to_swap_in)
if (worker_input.blocks_to_swap_out is not None if (worker_input.blocks_to_swap_out is not None
and worker_input.blocks_to_swap_out.numel() > 0): and worker_input.blocks_to_swap_out.numel() > 0):
self.cache_engine.swap_out(worker_input.blocks_to_swap_out) self.cache_engine[virtual_engine].swap_out(
worker_input.blocks_to_swap_out)
if (worker_input.blocks_to_copy is not None if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0): and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine.copy(worker_input.blocks_to_copy) self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request) return self.model_runner.add_lora(lora_request)
......
...@@ -6,10 +6,11 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union ...@@ -6,10 +6,11 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import torch import torch
from vllm.distributed import broadcast_tensor_dict from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread, is_hip, from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
update_environment_variables) update_environment_variables)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
...@@ -124,6 +125,7 @@ class WorkerInput: ...@@ -124,6 +125,7 @@ class WorkerInput:
blocks_to_swap_in: Optional[torch.Tensor] = None blocks_to_swap_in: Optional[torch.Tensor] = None
blocks_to_swap_out: Optional[torch.Tensor] = None blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None blocks_to_copy: Optional[torch.Tensor] = None
virtual_engine: int = 0
@classmethod @classmethod
def from_broadcasted_tensor_dict( def from_broadcasted_tensor_dict(
...@@ -139,6 +141,7 @@ class WorkerInput: ...@@ -139,6 +141,7 @@ class WorkerInput:
blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
) )
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
...@@ -151,6 +154,7 @@ class WorkerInput: ...@@ -151,6 +154,7 @@ class WorkerInput:
"blocks_to_swap_in": self.blocks_to_swap_in, "blocks_to_swap_in": self.blocks_to_swap_in,
"blocks_to_swap_out": self.blocks_to_swap_out, "blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy, "blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
} }
return tensor_dict return tensor_dict
...@@ -181,11 +185,13 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -181,11 +185,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
@property @property
@abstractmethod @abstractmethod
def kv_cache(self) -> Optional[List[torch.Tensor]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
""" """
Get the kv cache to pass to the worker's model runner. Used by the Gets the list of kv caches to pass to the worker's model runner. Each
default `execute_model`. If the worker's model runner does not follow element in the list is a kv cache corresponding to a particular virtual
the ModelRunnerBase interface, then inherit from WorkerBase instead. engine (PP stream). Used by the default `execute_model`. If the worker's
model runner does not follow the ModelRunnerBase interface, then inherit
from WorkerBase instead.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -227,7 +233,8 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -227,7 +233,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = ( model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list)) execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine))
num_steps = execute_model_req.num_steps num_steps = execute_model_req.num_steps
if self.do_metadata_broadcast: if self.do_metadata_broadcast:
...@@ -255,9 +262,24 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -255,9 +262,24 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if worker_input.num_seq_groups == 0: if worker_input.num_seq_groups == 0:
return [] return []
return self.model_runner.execute_model(model_input, self.kv_cache, intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict())
output = self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None, intermediate_tensors,
num_steps) num_steps)
if not get_pp_group().is_last_rank:
get_pp_group().send_tensor_dict(output.tensors)
return [None]
# Worker only supports single-step execution. Wrap the output in a
# list to conform to interface.
return output
class WorkerWrapperBase: class WorkerWrapperBase:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment