Unverified commit c5832d2a, authored by Murali Andoorveedu, committed by GitHub
Browse files

[Core] Pipeline Parallel Support (#4412)


Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
parent 15aba081
...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
def load_column_parallel_weight(param: torch.nn.Parameter, def load_column_parallel_weight(param: torch.nn.Parameter,
...@@ -412,6 +412,7 @@ class Phi3SmallForCausalLM(nn.Module): ...@@ -412,6 +412,7 @@ class Phi3SmallForCausalLM(nn.Module):
positions: Optional[torch.LongTensor], positions: Optional[torch.LongTensor],
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
output_hidden_states = self.model( output_hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
......
...@@ -35,7 +35,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel ...@@ -35,7 +35,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision from .interfaces import SupportsVision
...@@ -381,9 +381,13 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -381,9 +381,13 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return None return None
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, **kwargs: object): attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object):
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
...@@ -398,6 +402,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -398,6 +402,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
...@@ -27,7 +27,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -27,7 +27,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
...@@ -245,6 +245,7 @@ class QWenLMHeadModel(nn.Module): ...@@ -245,6 +245,7 @@ class QWenLMHeadModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -331,6 +331,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -331,6 +331,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -50,7 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -50,7 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class Qwen2MoeMLP(nn.Module): class Qwen2MoeMLP(nn.Module):
...@@ -397,6 +397,7 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -397,6 +397,7 @@ class Qwen2MoeForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -41,7 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -41,7 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class StablelmMLP(nn.Module): class StablelmMLP(nn.Module):
...@@ -250,6 +250,7 @@ class StablelmForCausalLM(nn.Module): ...@@ -250,6 +250,7 @@ class StablelmForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class Starcoder2Attention(nn.Module): class Starcoder2Attention(nn.Module):
...@@ -262,6 +262,7 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -262,6 +262,7 @@ class Starcoder2ForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -320,6 +320,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA): ...@@ -320,6 +320,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -770,6 +770,34 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput): ...@@ -770,6 +770,34 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
return self.embeddings == other.embeddings return self.embeddings == other.embeddings
@dataclass
class IntermediateTensors:
    """Named tensors handed from one pipeline stage to the next.

    For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
    """

    # Maps a tensor name to the tensor stored under that name.
    tensors: Dict[str, torch.Tensor]

    def __getitem__(self, key: Union[str, slice]):
        """Return one tensor by name, or a sliced copy of the container.

        Args:
            key: A ``str`` selects the tensor stored under that name. A
                ``slice`` is applied to every stored tensor and the results
                are wrapped in a new instance of this class.

        Raises:
            KeyError: If ``key`` is a ``str`` that is not present.
            TypeError: If ``key`` is neither a ``str`` nor a ``slice``.
        """
        if isinstance(key, str):
            return self.tensors[key]
        if isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})
        # Previously this fell through and returned None, which hid caller
        # mistakes until the None was used somewhere else.
        raise TypeError(
            f"IntermediateTensors indices must be str or slice, "
            f"not {type(key).__name__}")

    def __setitem__(self, key: str, value: torch.Tensor) -> None:
        """Store ``value`` under the name ``key``."""
        self.tensors[key] = value

    def __len__(self) -> int:
        """Return the number of stored tensors."""
        return len(self.tensors)

    def __eq__(self, other: object) -> bool:
        """Return True when ``other`` holds the same names and equal tensors.

        The earlier implementation returned ``self`` for any operand of the
        same class, so the contents were never compared and the result was
        not a bool.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
        # torch.equal yields a single bool per pair; a plain ``==`` on tensors
        # would produce an element-wise tensor instead.
        return all(
            torch.equal(tensor, other.tensors[name])
            for name, tensor in self.tensors.items())

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"
@dataclass @dataclass
class SamplerOutput: class SamplerOutput:
"""For each sequence group, we generate a list of SequenceOutput object, """For each sequence group, we generate a list of SequenceOutput object,
...@@ -896,6 +924,8 @@ class ExecuteModelRequest: ...@@ -896,6 +924,8 @@ class ExecuteModelRequest:
blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
# Blocks to copy. Source to dest block. # Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
# Virtual engine ID for pipeline parallel.
virtual_engine: int = 0
# The number of slots for lookahead decoding. # The number of slots for lookahead decoding.
num_lookahead_slots: int = 0 num_lookahead_slots: int = 0
# The number of requests in the running queue. # The number of requests in the running queue.
...@@ -914,6 +944,7 @@ class ExecuteModelRequest: ...@@ -914,6 +944,7 @@ class ExecuteModelRequest:
blocks_to_swap_in=self.blocks_to_swap_in.copy(), blocks_to_swap_in=self.blocks_to_swap_in.copy(),
blocks_to_swap_out=self.blocks_to_swap_out.copy(), blocks_to_swap_out=self.blocks_to_swap_out.copy(),
blocks_to_copy=self.blocks_to_copy.copy(), blocks_to_copy=self.blocks_to_copy.copy(),
virtual_engine=self.virtual_engine,
num_lookahead_slots=self.num_lookahead_slots, num_lookahead_slots=self.num_lookahead_slots,
running_queue_size=self.running_queue_size, running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states, previous_hidden_states=self.previous_hidden_states,
......
...@@ -6,7 +6,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -6,7 +6,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig, ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner) ModelRunner)
...@@ -74,9 +75,9 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -74,9 +75,9 @@ class TP1DraftModelRunner(ModelRunner):
List[SequenceGroupMetadata]] = None List[SequenceGroupMetadata]] = None
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInputForGPUWithSamplingMetadata: virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata:
"""A temporary solution that caches the seq_group_metadata_list """A temporary solution that caches the seq_group_metadata_list
for multi-step execution. for multi-step execution.
TODO: In-place update model_input and remove this function. TODO: In-place update model_input and remove this function.
...@@ -115,6 +116,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -115,6 +116,7 @@ class TP1DraftModelRunner(ModelRunner):
self, self,
model_input: ModelInputForGPUWithSamplingMetadata, model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
# Since we do not broadcast data inside execute_model anymore, # Since we do not broadcast data inside execute_model anymore,
...@@ -130,6 +132,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -130,6 +132,7 @@ class TP1DraftModelRunner(ModelRunner):
self.set_active_loras(model_input.lora_requests, self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping) model_input.lora_mapping)
virtual_engine = model_input.virtual_engine
outputs: List[SamplerOutput] = [] outputs: List[SamplerOutput] = []
for step in range(num_steps): for step in range(num_steps):
# Currently cuda graph is only supported by the decode phase. # Currently cuda graph is only supported by the decode phase.
...@@ -139,7 +142,8 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -139,7 +142,8 @@ class TP1DraftModelRunner(ModelRunner):
if prefill_meta is None and decode_meta.use_cuda_graph: if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0] graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size] model_executable = (
self.graph_runners[virtual_engine][graph_batch_size])
else: else:
model_executable = self.model model_executable = self.model
...@@ -149,6 +153,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -149,6 +153,7 @@ class TP1DraftModelRunner(ModelRunner):
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata, attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs, **multi_modal_kwargs,
) )
......
...@@ -38,7 +38,11 @@ class CacheEngine: ...@@ -38,7 +38,11 @@ class CacheEngine:
self.block_size = cache_config.block_size self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks self.num_gpu_blocks = cache_config.num_gpu_blocks
if self.num_gpu_blocks:
self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
self.num_cpu_blocks = cache_config.num_cpu_blocks self.num_cpu_blocks = cache_config.num_cpu_blocks
if self.num_cpu_blocks:
self.num_cpu_blocks //= parallel_config.pipeline_parallel_size
if cache_config.cache_dtype == "auto": if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype self.dtype = model_config.dtype
......
...@@ -13,7 +13,8 @@ from vllm.logger import init_logger ...@@ -13,7 +13,8 @@ from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import make_tensor_with_pad from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerBase, ModelRunnerInputBase,
...@@ -315,6 +316,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -315,6 +316,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> CPUModelInput: ) -> CPUModelInput:
multi_modal_kwargs = None multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
...@@ -351,6 +353,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -351,6 +353,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self, self,
model_input: CPUModelInput, model_input: CPUModelInput,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
if num_steps > 1: if num_steps > 1:
......
...@@ -167,8 +167,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -167,8 +167,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
is_driver_worker=is_driver_worker) is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
self.cache_engine: CPUCacheEngine self.cache_engine: List[CPUCacheEngine]
self.cpu_cache: List[torch.Tensor] self.cpu_cache: List[List[torch.Tensor]]
def init_device(self) -> None: def init_device(self) -> None:
self.init_distributed_environment() self.init_distributed_environment()
...@@ -242,25 +242,32 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -242,25 +242,32 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"initializing the engine.") "initializing the engine.")
def _init_cache_engine(self) -> None: def _init_cache_engine(self) -> None:
self.cache_engine = CPUCacheEngine(self.cache_config, self.cache_engine = [
self.model_config, CPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.parallel_config, self.device_config)
self.device_config) for _ in range(self.parallel_config.pipeline_parallel_size)
self.cpu_cache = self.cache_engine.cpu_cache ]
self.model_runner.block_size = self.cache_engine.block_size self.cpu_cache = [
self.cache_engine[ve].cpu_cache
assert self.cpu_cache is not None for ve in range(self.parallel_config.pipeline_parallel_size)
]
self.model_runner.block_size = self.cache_engine[0].block_size
assert all(
self.cpu_cache[ve] is not None
for ve in range(self.parallel_config.pipeline_parallel_size))
# Populate the cache to warmup the memory # Populate the cache to warmup the memory
for layer_cache in self.cpu_cache: for ve in range(self.parallel_config.pipeline_parallel_size):
layer_cache.fill_(0) for layer_cache in self.cpu_cache[ve]:
layer_cache.fill_(0)
@property @property
def do_metadata_broadcast(self) -> bool: def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1 return self.parallel_config.tensor_parallel_size > 1
@property @property
def kv_cache(self) -> Optional[List[torch.Tensor]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.cpu_cache return self.cpu_cache
def execute_worker( def execute_worker(
...@@ -269,12 +276,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -269,12 +276,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
) -> None: ) -> None:
if (worker_input.blocks_to_copy is not None if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0): and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine.copy(worker_input.blocks_to_copy) self.cache_engine[worker_input.virtual_engine].copy(
worker_input.blocks_to_copy)
@torch.inference_mode() @torch.inference_mode()
def prepare_worker_input( def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput: self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
assert execute_model_req is not None assert execute_model_req is not None
virtual_engine = execute_model_req.virtual_engine
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
blocks_to_copy = execute_model_req.blocks_to_copy blocks_to_copy = execute_model_req.blocks_to_copy
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
...@@ -285,6 +294,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -285,6 +294,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
return WorkerInput( return WorkerInput(
num_seq_groups=num_seq_groups, num_seq_groups=num_seq_groups,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
) )
def init_distributed_environment(self) -> None: def init_distributed_environment(self) -> None:
......
...@@ -9,7 +9,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -9,7 +9,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -57,6 +58,7 @@ class EmbeddingModelRunner( ...@@ -57,6 +58,7 @@ class EmbeddingModelRunner(
self, self,
model_input: ModelInputForGPUWithPoolingMetadata, model_input: ModelInputForGPUWithPoolingMetadata,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[PoolerOutput]]: ) -> Optional[List[PoolerOutput]]:
if num_steps > 1: if num_steps > 1:
...@@ -73,10 +75,12 @@ class EmbeddingModelRunner( ...@@ -73,10 +75,12 @@ class EmbeddingModelRunner(
assert model_input.attn_metadata is not None assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata decode_meta = model_input.attn_metadata.decode_metadata
virtual_engine = model_input.virtual_engine
if prefill_meta is None and decode_meta.use_cuda_graph: if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0] graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size] model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
else: else:
model_executable = self.model model_executable = self.model
...@@ -115,6 +119,7 @@ class EmbeddingModelRunner( ...@@ -115,6 +119,7 @@ class EmbeddingModelRunner(
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
) -> ModelInputForGPUWithPoolingMetadata: ) -> ModelInputForGPUWithPoolingMetadata:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors( model_input = self._prepare_model_input_tensors(
......
This diff is collapsed.
...@@ -5,7 +5,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, ...@@ -5,7 +5,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import torch import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
...@@ -137,6 +138,7 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -137,6 +138,7 @@ class ModelRunnerBase(ABC, Generic[T]):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> T: ) -> T:
""" """
Prepare the inputs to ModelRunnerBase.execute_model from an execution Prepare the inputs to ModelRunnerBase.execute_model from an execution
...@@ -150,6 +152,7 @@ class ModelRunnerBase(ABC, Generic[T]): ...@@ -150,6 +152,7 @@ class ModelRunnerBase(ABC, Generic[T]):
self, self,
model_input: T, model_input: T,
kv_caches: Optional[List[torch.Tensor]], kv_caches: Optional[List[torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors],
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
""" """
......
...@@ -9,7 +9,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, ...@@ -9,7 +9,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
...@@ -175,6 +176,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -175,6 +176,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> ModelInputForNeuron: ) -> ModelInputForNeuron:
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
# all decodes. # all decodes.
...@@ -207,6 +209,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -207,6 +209,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self, self,
model_input: ModelInputForNeuron, model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None, kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,
) -> Optional[List[SamplerOutput]]: ) -> Optional[List[SamplerOutput]]:
if num_steps > 1: if num_steps > 1:
......
...@@ -80,7 +80,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -80,7 +80,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
return False return False
@property @property
def kv_cache(self) -> Optional[List[torch.Tensor]]: def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None return None
@torch.inference_mode() @torch.inference_mode()
......
...@@ -59,9 +59,9 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -59,9 +59,9 @@ class Worker(LocalOrDistributedWorkerBase):
self.lora_config = lora_config self.lora_config = lora_config
self.load_config = load_config self.load_config = load_config
self.is_driver_worker = is_driver_worker self.is_driver_worker = is_driver_worker
if self.is_driver_worker: if parallel_config and is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0." assert rank % parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group."
if self.model_config.trust_remote_code: if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules from vllm.utils import init_cached_hf_modules
...@@ -99,9 +99,9 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -99,9 +99,9 @@ class Worker(LocalOrDistributedWorkerBase):
) )
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# initialize_cache. # initialize_cache.
self.cache_engine: CacheEngine self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as embedding models don't initialize kv_caches # Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[torch.tensor]] = None self.gpu_cache: Optional[List[List[torch.tensor]]] = None
def init_device(self) -> None: def init_device(self) -> None:
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
...@@ -217,10 +217,15 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -217,10 +217,15 @@ class Worker(LocalOrDistributedWorkerBase):
def _init_cache_engine(self):
    """Build one cache engine per pipeline-parallel slot.

    Populates ``self.cache_engine`` with
    ``parallel_config.pipeline_parallel_size`` independently constructed
    ``CacheEngine`` instances and ``self.gpu_cache`` with the matching
    ``gpu_cache`` of each engine, in the same order.
    """
    assert self.cache_config.num_gpu_blocks is not None
    self.cache_engine = [
        CacheEngine(self.cache_config, self.model_config,
                    self.parallel_config, self.device_config)
        for _ in range(self.parallel_config.pipeline_parallel_size)
    ]
    # Iterate the engines directly; the list was just built with exactly
    # pipeline_parallel_size entries, so this matches indexing by range.
    self.gpu_cache = [engine.gpu_cache for engine in self.cache_engine]
def _warm_up_model(self) -> None: def _warm_up_model(self) -> None:
if not self.model_config.enforce_eager: if not self.model_config.enforce_eager:
...@@ -234,12 +239,13 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -234,12 +239,13 @@ class Worker(LocalOrDistributedWorkerBase):
return self.parallel_config.tensor_parallel_size > 1 return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
    """Return ``self.gpu_cache`` unchanged.

    The value is whatever is currently stored on the attribute; no copy is
    made.
    """
    return self.gpu_cache
@torch.inference_mode() @torch.inference_mode()
def prepare_worker_input( def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput: self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_seq_groups = len(execute_model_req.seq_group_metadata_list) num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync. # they contain parameters to launch cudamemcpyasync.
...@@ -261,20 +267,24 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -261,20 +267,24 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
) )
@torch.inference_mode()
def execute_worker(self, worker_input: WorkerInput) -> None:
    """Issue the cache operations requested by ``worker_input``.

    Swap-in, swap-out and copy are dispatched, in that order, to the cache
    engine selected by ``worker_input.virtual_engine``. An operation is
    skipped when its block tensor is ``None`` or has no elements.
    """
    virtual_engine = worker_input.virtual_engine
    # Issue cache operations.
    if (worker_input.blocks_to_swap_in is not None
            and worker_input.blocks_to_swap_in.numel() > 0):
        self.cache_engine[virtual_engine].swap_in(
            worker_input.blocks_to_swap_in)
    if (worker_input.blocks_to_swap_out is not None
            and worker_input.blocks_to_swap_out.numel() > 0):
        self.cache_engine[virtual_engine].swap_out(
            worker_input.blocks_to_swap_out)
    if (worker_input.blocks_to_copy is not None
            and worker_input.blocks_to_copy.numel() > 0):
        self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Forward ``lora_request`` to the model runner and return its result."""
    return self.model_runner.add_lora(lora_request)
......
...@@ -6,10 +6,11 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union ...@@ -6,10 +6,11 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import torch import torch
from vllm.distributed import broadcast_tensor_dict from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread, is_hip, from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
update_environment_variables) update_environment_variables)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
...@@ -124,6 +125,7 @@ class WorkerInput: ...@@ -124,6 +125,7 @@ class WorkerInput:
blocks_to_swap_in: Optional[torch.Tensor] = None blocks_to_swap_in: Optional[torch.Tensor] = None
blocks_to_swap_out: Optional[torch.Tensor] = None blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None blocks_to_copy: Optional[torch.Tensor] = None
virtual_engine: int = 0
@classmethod @classmethod
def from_broadcasted_tensor_dict( def from_broadcasted_tensor_dict(
...@@ -139,6 +141,7 @@ class WorkerInput: ...@@ -139,6 +141,7 @@ class WorkerInput:
blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
) )
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
...@@ -151,6 +154,7 @@ class WorkerInput: ...@@ -151,6 +154,7 @@ class WorkerInput:
"blocks_to_swap_in": self.blocks_to_swap_in, "blocks_to_swap_in": self.blocks_to_swap_in,
"blocks_to_swap_out": self.blocks_to_swap_out, "blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy, "blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
} }
return tensor_dict return tensor_dict
...@@ -181,11 +185,13 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -181,11 +185,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
@property
@abstractmethod
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
    """
    Gets the list of kv caches to pass to the worker's model runner. Each
    element in the list is a kv cache corresponding to a particular virtual
    engine (PP stream). Used by the default `execute_model`. If the worker's
    model runner does not follow the ModelRunnerBase interface, then inherit
    from WorkerBase instead.
    """
    raise NotImplementedError
...@@ -227,7 +233,8 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -227,7 +233,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
execute_model_req=execute_model_req) execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = ( model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input( self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list)) execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine))
num_steps = execute_model_req.num_steps num_steps = execute_model_req.num_steps
if self.do_metadata_broadcast: if self.do_metadata_broadcast:
...@@ -255,8 +262,23 @@ class LocalOrDistributedWorkerBase(WorkerBase): ...@@ -255,8 +262,23 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if worker_input.num_seq_groups == 0: if worker_input.num_seq_groups == 0:
return [] return []
return self.model_runner.execute_model(model_input, self.kv_cache, intermediate_tensors = None
num_steps) if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict())
output = self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None, intermediate_tensors,
num_steps)
if not get_pp_group().is_last_rank:
get_pp_group().send_tensor_dict(output.tensors)
return [None]
# Worker only supports single-step execution. Wrap the output in a
# list to conform to interface.
return output
class WorkerWrapperBase: class WorkerWrapperBase:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment