"wrappers/python/vscode:/vscode.git/clone" did not exist on "2f74f96bbb0404185d35e6f03713e570acc4b9ff"
Unverified Commit c5832d2a authored by Murali Andoorveedu's avatar Murali Andoorveedu Committed by GitHub
Browse files

[Core] Pipeline Parallel Support (#4412)


Signed-off-by: default avatarMuralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
parent 15aba081
......@@ -21,7 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
def load_column_parallel_weight(param: torch.nn.Parameter,
......@@ -412,6 +412,7 @@ class Phi3SmallForCausalLM(nn.Module):
positions: Optional[torch.LongTensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
output_hidden_states = self.model(
input_ids=input_ids,
......
......@@ -35,7 +35,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision
......@@ -381,9 +381,13 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return None
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, **kwargs: object):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
......@@ -398,6 +402,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
......
......@@ -27,7 +27,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once
......@@ -245,6 +245,7 @@ class QWenLMHeadModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA
......@@ -331,6 +331,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -50,7 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
class Qwen2MoeMLP(nn.Module):
......@@ -397,6 +397,7 @@ class Qwen2MoeForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -41,7 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
class StablelmMLP(nn.Module):
......@@ -250,6 +250,7 @@ class StablelmForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
class Starcoder2Attention(nn.Module):
......@@ -262,6 +262,7 @@ class Starcoder2ForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA
......@@ -320,6 +320,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -770,6 +770,34 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
return self.embeddings == other.embeddings
@dataclass
class IntermediateTensors:
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
"""
tensors: Dict[str, torch.Tensor]
def __getitem__(self, key: Union[str, slice]):
if isinstance(key, str):
return self.tensors[key]
elif isinstance(key, slice):
return self.__class__({k: v[key] for k, v in self.tensors.items()})
def __setitem__(self, key: str, value):
self.tensors[key] = value
def __len__(self):
return len(self.tensors)
def __eq__(self, other: object):
return isinstance(other, self.__class__) and self
def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})"
@dataclass
class SamplerOutput:
"""For each sequence group, we generate a list of SequenceOutput object,
......@@ -896,6 +924,8 @@ class ExecuteModelRequest:
blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
# Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
# Virtual engine ID for pipeline parallel.
virtual_engine: int = 0
# The number of slots for lookahead decoding.
num_lookahead_slots: int = 0
# The number of requests in the running queue.
......@@ -914,6 +944,7 @@ class ExecuteModelRequest:
blocks_to_swap_in=self.blocks_to_swap_in.copy(),
blocks_to_swap_out=self.blocks_to_swap_out.copy(),
blocks_to_copy=self.blocks_to_copy.copy(),
virtual_engine=self.virtual_engine,
num_lookahead_slots=self.num_lookahead_slots,
running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states,
......
......@@ -6,7 +6,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
......@@ -74,9 +75,9 @@ class TP1DraftModelRunner(ModelRunner):
List[SequenceGroupMetadata]] = None
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInputForGPUWithSamplingMetadata:
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata:
"""A temporary solution that caches the seq_group_metadata_list
for multi-step execution.
TODO: In-place update model_input and remove this function.
......@@ -115,6 +116,7 @@ class TP1DraftModelRunner(ModelRunner):
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
# Since we do not broadcast data inside execute_model anymore,
......@@ -130,6 +132,7 @@ class TP1DraftModelRunner(ModelRunner):
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
virtual_engine = model_input.virtual_engine
outputs: List[SamplerOutput] = []
for step in range(num_steps):
# Currently cuda graph is only supported by the decode phase.
......@@ -139,7 +142,8 @@ class TP1DraftModelRunner(ModelRunner):
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
model_executable = (
self.graph_runners[virtual_engine][graph_batch_size])
else:
model_executable = self.model
......@@ -149,6 +153,7 @@ class TP1DraftModelRunner(ModelRunner):
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs,
)
......
......@@ -38,7 +38,11 @@ class CacheEngine:
self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks
if self.num_gpu_blocks:
self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
self.num_cpu_blocks = cache_config.num_cpu_blocks
if self.num_cpu_blocks:
self.num_cpu_blocks //= parallel_config.pipeline_parallel_size
if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype
......
......@@ -13,7 +13,8 @@ from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
......@@ -315,6 +316,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> CPUModelInput:
multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or
......@@ -351,6 +353,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self,
model_input: CPUModelInput,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
......
......@@ -167,8 +167,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: CPUCacheEngine
self.cpu_cache: List[torch.Tensor]
self.cache_engine: List[CPUCacheEngine]
self.cpu_cache: List[List[torch.Tensor]]
def init_device(self) -> None:
self.init_distributed_environment()
......@@ -242,25 +242,32 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"initializing the engine.")
def _init_cache_engine(self) -> None:
self.cache_engine = CPUCacheEngine(self.cache_config,
self.model_config,
self.parallel_config,
self.device_config)
self.cpu_cache = self.cache_engine.cpu_cache
self.model_runner.block_size = self.cache_engine.block_size
assert self.cpu_cache is not None
self.cache_engine = [
CPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.cpu_cache = [
self.cache_engine[ve].cpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
self.model_runner.block_size = self.cache_engine[0].block_size
assert all(
self.cpu_cache[ve] is not None
for ve in range(self.parallel_config.pipeline_parallel_size))
# Populate the cache to warmup the memory
for layer_cache in self.cpu_cache:
layer_cache.fill_(0)
for ve in range(self.parallel_config.pipeline_parallel_size):
for layer_cache in self.cpu_cache[ve]:
layer_cache.fill_(0)
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[torch.Tensor]]:
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.cpu_cache
def execute_worker(
......@@ -269,12 +276,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
) -> None:
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine.copy(worker_input.blocks_to_copy)
self.cache_engine[worker_input.virtual_engine].copy(
worker_input.blocks_to_copy)
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
assert execute_model_req is not None
virtual_engine = execute_model_req.virtual_engine
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
blocks_to_copy = execute_model_req.blocks_to_copy
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
......@@ -285,6 +294,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
def init_distributed_environment(self) -> None:
......
......@@ -9,7 +9,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
logger = init_logger(__name__)
......@@ -57,6 +58,7 @@ class EmbeddingModelRunner(
self,
model_input: ModelInputForGPUWithPoolingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[PoolerOutput]]:
if num_steps > 1:
......@@ -73,10 +75,12 @@ class EmbeddingModelRunner(
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
virtual_engine = model_input.virtual_engine
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
else:
model_executable = self.model
......@@ -115,6 +119,7 @@ class EmbeddingModelRunner(
def prepare_model_input(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
) -> ModelInputForGPUWithPoolingMetadata:
assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors(
......
This diff is collapsed.
......@@ -5,7 +5,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
......@@ -137,6 +138,7 @@ class ModelRunnerBase(ABC, Generic[T]):
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> T:
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
......@@ -150,6 +152,7 @@ class ModelRunnerBase(ABC, Generic[T]):
self,
model_input: T,
kv_caches: Optional[List[torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors],
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
"""
......
......@@ -9,7 +9,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
......@@ -175,6 +176,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> ModelInputForNeuron:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
......@@ -207,6 +209,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self,
model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
......
......@@ -80,7 +80,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
return False
@property
def kv_cache(self) -> Optional[List[torch.Tensor]]:
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None
@torch.inference_mode()
......
......@@ -59,9 +59,9 @@ class Worker(LocalOrDistributedWorkerBase):
self.lora_config = lora_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
if parallel_config and is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
......@@ -99,9 +99,9 @@ class Worker(LocalOrDistributedWorkerBase):
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: CacheEngine
self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[torch.tensor]] = None
self.gpu_cache: Optional[List[List[torch.tensor]]] = None
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
......@@ -217,10 +217,15 @@ class Worker(LocalOrDistributedWorkerBase):
def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config,
self.device_config)
self.gpu_cache = self.cache_engine.gpu_cache
self.cache_engine = [
CacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.gpu_cache = [
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
def _warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
......@@ -234,12 +239,13 @@ class Worker(LocalOrDistributedWorkerBase):
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[torch.Tensor]]:
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.gpu_cache
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync.
......@@ -261,20 +267,24 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
@torch.inference_mode()
def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
# Issue cache operations.
if (worker_input.blocks_to_swap_in is not None
and worker_input.blocks_to_swap_in.numel() > 0):
self.cache_engine.swap_in(worker_input.blocks_to_swap_in)
self.cache_engine[virtual_engine].swap_in(
worker_input.blocks_to_swap_in)
if (worker_input.blocks_to_swap_out is not None
and worker_input.blocks_to_swap_out.numel() > 0):
self.cache_engine.swap_out(worker_input.blocks_to_swap_out)
self.cache_engine[virtual_engine].swap_out(
worker_input.blocks_to_swap_out)
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine.copy(worker_input.blocks_to_copy)
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
......
......@@ -6,10 +6,11 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from vllm.distributed import broadcast_tensor_dict
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
update_environment_variables)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
......@@ -124,6 +125,7 @@ class WorkerInput:
blocks_to_swap_in: Optional[torch.Tensor] = None
blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None
virtual_engine: int = 0
@classmethod
def from_broadcasted_tensor_dict(
......@@ -139,6 +141,7 @@ class WorkerInput:
blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
)
def as_broadcastable_tensor_dict(
......@@ -151,6 +154,7 @@ class WorkerInput:
"blocks_to_swap_in": self.blocks_to_swap_in,
"blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
}
return tensor_dict
......@@ -181,11 +185,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
@property
@abstractmethod
def kv_cache(self) -> Optional[List[torch.Tensor]]:
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
"""
Get the kv cache to pass to the worker's model runner. Used by the
default `execute_model`. If the worker's model runner does not follow
the ModelRunnerBase interface, then inherit from WorkerBase instead.
Gets the list of kv caches to pass to the worker's model runner. Each
element in the list is a kv cache corresponding to a particular virtual
engine (PP stream). Used by the default `execute_model`. If the worker's
model runner does not follow the ModelRunnerBase interface, then inherit
from WorkerBase instead.
"""
raise NotImplementedError
......@@ -227,7 +233,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list))
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine))
num_steps = execute_model_req.num_steps
if self.do_metadata_broadcast:
......@@ -255,8 +262,23 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if worker_input.num_seq_groups == 0:
return []
return self.model_runner.execute_model(model_input, self.kv_cache,
num_steps)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict())
output = self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None, intermediate_tensors,
num_steps)
if not get_pp_group().is_last_rank:
get_pp_group().send_tensor_dict(output.tensors)
return [None]
# Worker only supports single-step execution. Wrap the output in a
# list to conform to interface.
return output
class WorkerWrapperBase:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment