Unverified Commit c5832d2a authored by Murali Andoorveedu's avatar Murali Andoorveedu Committed by GitHub
Browse files

[Core] Pipeline Parallel Support (#4412)


Signed-off-by: default avatarMuralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
parent 15aba081
......@@ -21,7 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
def load_column_parallel_weight(param: torch.nn.Parameter,
......@@ -412,6 +412,7 @@ class Phi3SmallForCausalLM(nn.Module):
positions: Optional[torch.LongTensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
output_hidden_states = self.model(
input_ids=input_ids,
......
......@@ -35,7 +35,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision
......@@ -381,9 +381,13 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return None
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, **kwargs: object):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
......@@ -398,6 +402,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
......
......@@ -27,7 +27,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once
......@@ -245,6 +245,7 @@ class QWenLMHeadModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -45,7 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA
......@@ -331,6 +331,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -50,7 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
class Qwen2MoeMLP(nn.Module):
......@@ -397,6 +397,7 @@ class Qwen2MoeForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -41,7 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
class StablelmMLP(nn.Module):
......@@ -250,6 +250,7 @@ class StablelmForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
class Starcoder2Attention(nn.Module):
......@@ -262,6 +262,7 @@ class Starcoder2ForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA
......@@ -320,6 +320,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
......
......@@ -770,6 +770,34 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
return self.embeddings == other.embeddings
@dataclass
class IntermediateTensors:
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
"""
tensors: Dict[str, torch.Tensor]
def __getitem__(self, key: Union[str, slice]):
if isinstance(key, str):
return self.tensors[key]
elif isinstance(key, slice):
return self.__class__({k: v[key] for k, v in self.tensors.items()})
def __setitem__(self, key: str, value):
self.tensors[key] = value
def __len__(self):
return len(self.tensors)
def __eq__(self, other: object):
return isinstance(other, self.__class__) and self
def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})"
@dataclass
class SamplerOutput:
"""For each sequence group, we generate a list of SequenceOutput object,
......@@ -896,6 +924,8 @@ class ExecuteModelRequest:
blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
# Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
# Virtual engine ID for pipeline parallel.
virtual_engine: int = 0
# The number of slots for lookahead decoding.
num_lookahead_slots: int = 0
# The number of requests in the running queue.
......@@ -914,6 +944,7 @@ class ExecuteModelRequest:
blocks_to_swap_in=self.blocks_to_swap_in.copy(),
blocks_to_swap_out=self.blocks_to_swap_out.copy(),
blocks_to_copy=self.blocks_to_copy.copy(),
virtual_engine=self.virtual_engine,
num_lookahead_slots=self.num_lookahead_slots,
running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states,
......
......@@ -6,7 +6,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.logger import init_logger
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
......@@ -74,9 +75,9 @@ class TP1DraftModelRunner(ModelRunner):
List[SequenceGroupMetadata]] = None
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInputForGPUWithSamplingMetadata:
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata:
"""A temporary solution that caches the seq_group_metadata_list
for multi-step execution.
TODO: In-place update model_input and remove this function.
......@@ -115,6 +116,7 @@ class TP1DraftModelRunner(ModelRunner):
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
# Since we do not broadcast data inside execute_model anymore,
......@@ -130,6 +132,7 @@ class TP1DraftModelRunner(ModelRunner):
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
virtual_engine = model_input.virtual_engine
outputs: List[SamplerOutput] = []
for step in range(num_steps):
# Currently cuda graph is only supported by the decode phase.
......@@ -139,7 +142,8 @@ class TP1DraftModelRunner(ModelRunner):
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
model_executable = (
self.graph_runners[virtual_engine][graph_batch_size])
else:
model_executable = self.model
......@@ -149,6 +153,7 @@ class TP1DraftModelRunner(ModelRunner):
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs,
)
......
......@@ -38,7 +38,11 @@ class CacheEngine:
self.block_size = cache_config.block_size
self.num_gpu_blocks = cache_config.num_gpu_blocks
if self.num_gpu_blocks:
self.num_gpu_blocks //= parallel_config.pipeline_parallel_size
self.num_cpu_blocks = cache_config.num_cpu_blocks
if self.num_cpu_blocks:
self.num_cpu_blocks //= parallel_config.pipeline_parallel_size
if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype
......
......@@ -13,7 +13,8 @@ from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
......@@ -315,6 +316,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> CPUModelInput:
multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or
......@@ -351,6 +353,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self,
model_input: CPUModelInput,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
......
......@@ -167,8 +167,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: CPUCacheEngine
self.cpu_cache: List[torch.Tensor]
self.cache_engine: List[CPUCacheEngine]
self.cpu_cache: List[List[torch.Tensor]]
def init_device(self) -> None:
self.init_distributed_environment()
......@@ -242,25 +242,32 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"initializing the engine.")
def _init_cache_engine(self) -> None:
self.cache_engine = CPUCacheEngine(self.cache_config,
self.model_config,
self.parallel_config,
self.device_config)
self.cpu_cache = self.cache_engine.cpu_cache
self.model_runner.block_size = self.cache_engine.block_size
assert self.cpu_cache is not None
self.cache_engine = [
CPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.cpu_cache = [
self.cache_engine[ve].cpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
self.model_runner.block_size = self.cache_engine[0].block_size
assert all(
self.cpu_cache[ve] is not None
for ve in range(self.parallel_config.pipeline_parallel_size))
# Populate the cache to warmup the memory
for layer_cache in self.cpu_cache:
layer_cache.fill_(0)
for ve in range(self.parallel_config.pipeline_parallel_size):
for layer_cache in self.cpu_cache[ve]:
layer_cache.fill_(0)
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[torch.Tensor]]:
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.cpu_cache
def execute_worker(
......@@ -269,12 +276,14 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
) -> None:
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine.copy(worker_input.blocks_to_copy)
self.cache_engine[worker_input.virtual_engine].copy(
worker_input.blocks_to_copy)
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
assert execute_model_req is not None
virtual_engine = execute_model_req.virtual_engine
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
blocks_to_copy = execute_model_req.blocks_to_copy
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
......@@ -285,6 +294,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
def init_distributed_environment(self) -> None:
......
......@@ -9,7 +9,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
logger = init_logger(__name__)
......@@ -57,6 +58,7 @@ class EmbeddingModelRunner(
self,
model_input: ModelInputForGPUWithPoolingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[PoolerOutput]]:
if num_steps > 1:
......@@ -73,10 +75,12 @@ class EmbeddingModelRunner(
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
virtual_engine = model_input.virtual_engine
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
else:
model_executable = self.model
......@@ -115,6 +119,7 @@ class EmbeddingModelRunner(
def prepare_model_input(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
) -> ModelInputForGPUWithPoolingMetadata:
assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors(
......
......@@ -8,6 +8,7 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
try:
......@@ -25,6 +26,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
......@@ -37,7 +39,8 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import supports_lora
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad)
from vllm.worker.model_runner_base import (
......@@ -81,6 +84,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
virtual_engine: int = 0
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
......@@ -89,6 +93,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
......@@ -122,6 +127,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
......@@ -179,7 +185,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture.
# When using CUDA graph, the input block tables must be padded to
......@@ -787,9 +796,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_num_seqs = min(
max_num_seqs,
int(max_num_batched_tokens / vlm_config.image_feature_size))
batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
batch_size += seq_len
seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
.dummy_data_for_profiling(model_config, seq_len)
......@@ -811,7 +822,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
model_input = self.prepare_model_input(seqs)
self.execute_model(model_input, kv_caches)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize()
return
......@@ -847,7 +864,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
return self.lora_manager.list_loras()
@torch.inference_mode()
def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
"""Cuda graph capture a model.
Note that CUDA graph's performance gain is negligible if number
......@@ -880,10 +897,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
slot_mapping.fill_(_PAD_SLOT_ID)
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
intermediate_inputs = None
if not get_pp_group().is_first_rank:
intermediate_inputs = self.model.make_empty_intermediate_tensors(
batch_size=max_batch_size,
dtype=self.model_config.dtype,
device=self.device)
# Prepare buffer for outputs. These will be reused for all batch sizes.
# It will be filled after the first graph capture.
hidden_states: Optional[torch.Tensor] = None
hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
None
] * self.parallel_config.pipeline_parallel_size
graph_batch_size = _get_graph_batch_size(
self.scheduler_config.max_num_seqs)
......@@ -912,109 +937,120 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
with graph_capture() as graph_capture_context:
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
if self.attn_backend.get_name() == "flashinfer":
indptr_buffer = indptr_buffer[:batch_size + 1]
last_page_len_buffer = last_page_len_buffer[:batch_size]
num_qo_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
num_kv_heads = self.model_config.get_num_kv_heads(
self.parallel_config)
if num_qo_heads // num_kv_heads >= 4:
use_tensor_cores = True
for virtual_engine in range(
self.parallel_config.pipeline_parallel_size):
for batch_size in reversed(batch_size_capture_list):
if self.attn_backend.get_name() == "flashinfer":
indptr_buffer = indptr_buffer[:batch_size + 1]
last_page_len_buffer = last_page_len_buffer[:
batch_size]
num_qo_heads = (
self.model_config.get_num_attention_heads(
self.parallel_config))
num_kv_heads = self.model_config.get_num_kv_heads(
self.parallel_config)
if num_qo_heads // num_kv_heads >= 4:
use_tensor_cores = True
else:
use_tensor_cores = False
decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
decode_workspace_buffer, indptr_buffer,
indices_buffer, last_page_len_buffer, "NHD",
use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype(
self.kv_cache_dtype, self.model_config.dtype)
paged_kv_indptr_tensor_host = torch.arange(
0, batch_size + 1, dtype=torch.int32)
paged_kv_indices_tensor_host = torch.arange(
0, batch_size, dtype=torch.int32)
paged_kv_last_page_len_tensor_host = torch.full(
(batch_size, ), self.block_size, dtype=torch.int32)
query_start_loc_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
slot_mapping=slot_mapping[:batch_size],
num_prefill_tokens=0,
num_decode_tokens=batch_size,
max_prefill_seq_len=0,
block_tables=block_tables,
paged_kv_indptr=paged_kv_indptr_tensor_host,
paged_kv_indices=paged_kv_indices_tensor_host,
paged_kv_last_page_len=
paged_kv_last_page_len_tensor_host,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=self.model_config.get_head_size(),
page_size=self.block_size,
seq_start_loc=None,
query_start_loc=query_start_loc_host,
device=self.device,
data_type=kv_cache_dtype,
use_cuda_graph=True,
decode_wrapper=decode_wrapper,
prefill_wrapper=None)
attn_metadata.begin_forward()
else:
use_tensor_cores = False
decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
decode_workspace_buffer, indptr_buffer, indices_buffer,
last_page_len_buffer, "NHD", use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype(
self.kv_cache_dtype, self.model_config.dtype)
paged_kv_indptr_tensor_host = torch.arange(
0, batch_size + 1, dtype=torch.int32)
paged_kv_indices_tensor_host = torch.arange(
0, batch_size, dtype=torch.int32)
paged_kv_last_page_len_tensor_host = torch.full(
(batch_size, ), self.block_size, dtype=torch.int32)
query_start_loc_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
slot_mapping=slot_mapping[:batch_size],
num_prefill_tokens=0,
num_decode_tokens=batch_size,
max_prefill_seq_len=0,
block_tables=block_tables,
paged_kv_indptr=paged_kv_indptr_tensor_host,
paged_kv_indices=paged_kv_indices_tensor_host,
paged_kv_last_page_len=
paged_kv_last_page_len_tensor_host,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=self.model_config.get_head_size(),
page_size=self.block_size,
seq_start_loc=None,
query_start_loc=query_start_loc_host,
device=self.device,
data_type=kv_cache_dtype,
use_cuda_graph=True,
decode_wrapper=decode_wrapper,
prefill_wrapper=None)
attn_metadata.begin_forward()
else:
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping[:batch_size],
seq_lens=None,
seq_lens_tensor=seq_lens[:batch_size],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping[:batch_size],
seq_lens=None,
seq_lens_tensor=seq_lens[:batch_size],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
if self.lora_config:
lora_mapping = LoRAMapping(
[0] * batch_size,
[0] * batch_size,
)
self.set_active_loras(set(), lora_mapping)
graph_runner = CUDAGraphRunner(
self.model, self.attn_backend.get_name())
if self.attn_backend.get_name() == "flashinfer":
graph_runner.flashinfer_indptr_buffer = indptr_buffer
graph_runner.flashinfer_indices_buffer = indices_buffer
graph_runner.flashinfer_last_page_len_buffer = \
last_page_len_buffer
graph_runner.flashinfer_decode_workspace_buffer = \
decode_workspace_buffer
graph_runner.flashinfer_decode_wrapper = \
decode_wrapper
graph_runner.capture(
input_tokens[:batch_size],
input_positions[:batch_size],
hidden_or_intermediate_states[
virtual_engine] # type: ignore
[:batch_size]
if hidden_or_intermediate_states[virtual_engine]
is not None else None,
intermediate_inputs[:batch_size]
if intermediate_inputs is not None else None,
kv_caches[virtual_engine],
attn_metadata,
memory_pool=self.graph_memory_pool,
stream=graph_capture_context.stream,
)
if self.lora_config:
lora_mapping = LoRAMapping(
[0] * batch_size,
[0] * batch_size,
)
self.set_active_loras(set(), lora_mapping)
graph_runner = CUDAGraphRunner(self.model,
self.attn_backend.get_name())
if self.attn_backend.get_name() == "flashinfer":
graph_runner.flashinfer_indptr_buffer = indptr_buffer
graph_runner.flashinfer_indices_buffer = indices_buffer
graph_runner.flashinfer_last_page_len_buffer = \
last_page_len_buffer
graph_runner.flashinfer_decode_workspace_buffer = \
decode_workspace_buffer
graph_runner.flashinfer_decode_wrapper = \
decode_wrapper
graph_runner.capture(
input_tokens[:batch_size],
input_positions[:batch_size],
hidden_states[:batch_size]
if hidden_states is not None else None,
kv_caches,
attn_metadata,
memory_pool=self.graph_memory_pool,
stream=graph_capture_context.stream,
)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[batch_size] = graph_runner
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = (
graph_runner)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
......@@ -1047,6 +1083,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
......@@ -1072,15 +1109,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
is_prompt=is_prompt)
is_prompt=is_prompt,
virtual_engine=virtual_engine)
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")
......@@ -1124,27 +1163,34 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
# TODO(andoorve): We can remove this once all
# virtual engines share the same kv cache.
virtual_engine = model_input.virtual_engine
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
else:
model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
hidden_states = model_executable(
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return []
......@@ -1159,9 +1205,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
assert model_input.sampling_metadata is not None
indices = model_input.sampling_metadata.selected_token_indices
if model_input.is_prompt:
hidden_states = hidden_states.index_select(0, indices)
hidden_states = hidden_or_intermediate_states.index_select(
0, indices)
elif decode_meta.use_cuda_graph:
hidden_states = hidden_states[:len(indices)]
hidden_states = hidden_or_intermediate_states[:len(indices)]
else:
hidden_states = hidden_or_intermediate_states
output.hidden_states = hidden_states
......@@ -1195,13 +1244,15 @@ class CUDAGraphRunner:
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: Optional[torch.Tensor],
hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
torch.Tensor]],
intermediate_inputs: Optional[IntermediateTensors],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
memory_pool: Optional[Tuple[int, int]],
stream: torch.cuda.Stream,
**kwargs,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
assert self._graph is None
# Run the model a few times without capturing the graph.
# This is to make sure that the captured graph does not include the
......@@ -1213,6 +1264,7 @@ class CUDAGraphRunner:
positions,
kv_caches,
attn_metadata,
intermediate_inputs,
**kwargs,
)
torch.cuda.synchronize()
......@@ -1220,18 +1272,27 @@ class CUDAGraphRunner:
# Capture the graph.
self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
output_hidden_states = self.model(
output_hidden_or_intermediate_states = self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_inputs,
**kwargs,
)
if hidden_states is not None:
hidden_states.copy_(output_hidden_states)
if hidden_or_intermediate_states is not None:
if get_pp_group().is_last_rank:
hidden_or_intermediate_states.copy_(
output_hidden_or_intermediate_states)
else:
for key in hidden_or_intermediate_states.tensors:
hidden_or_intermediate_states[key].copy_(
output_hidden_or_intermediate_states[key])
else:
hidden_states = output_hidden_states
del output_hidden_states
hidden_or_intermediate_states = (
output_hidden_or_intermediate_states)
del output_hidden_or_intermediate_states
# make sure `output_hidden_states` is deleted
# in the graph's memory pool
gc.collect()
......@@ -1255,8 +1316,15 @@ class CUDAGraphRunner:
attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
return hidden_states
if intermediate_inputs is not None:
self.input_buffers.update(intermediate_inputs.tensors)
if get_pp_group().is_last_rank:
self.output_buffers = {
"hidden_states": hidden_or_intermediate_states
}
else:
self.output_buffers = hidden_or_intermediate_states
return hidden_or_intermediate_states
def forward(
self,
......@@ -1264,6 +1332,7 @@ class CUDAGraphRunner:
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
**kwargs,
) -> torch.Tensor:
# KV caches are fixed tensors, so we don't need to copy them.
......@@ -1280,11 +1349,18 @@ class CUDAGraphRunner:
non_blocking=True)
self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
if intermediate_tensors is not None:
for key in intermediate_tensors.tensors:
self.input_buffers[key].copy_(intermediate_tensors[key],
non_blocking=True)
# Run the graph.
self.graph.replay()
# Return the output tensor.
return self.output_buffers["hidden_states"]
if get_pp_group().is_last_rank:
return self.output_buffers["hidden_states"]
return self.output_buffers
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
......
......@@ -5,7 +5,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
......@@ -137,6 +138,7 @@ class ModelRunnerBase(ABC, Generic[T]):
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> T:
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
......@@ -150,6 +152,7 @@ class ModelRunnerBase(ABC, Generic[T]):
self,
model_input: T,
kv_caches: Optional[List[torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors],
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
"""
......
......@@ -9,7 +9,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
......@@ -175,6 +176,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
) -> ModelInputForNeuron:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
......@@ -207,6 +209,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self,
model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
......
......@@ -80,7 +80,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
return False
@property
def kv_cache(self) -> Optional[List[torch.Tensor]]:
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return None
@torch.inference_mode()
......
......@@ -59,9 +59,9 @@ class Worker(LocalOrDistributedWorkerBase):
self.lora_config = lora_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
if parallel_config and is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \
"Driver worker should be rank 0 of tensor parallel group."
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
......@@ -99,9 +99,9 @@ class Worker(LocalOrDistributedWorkerBase):
)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: CacheEngine
self.cache_engine: List[CacheEngine]
# Initialize gpu_cache as embedding models don't initialize kv_caches
self.gpu_cache: Optional[List[torch.tensor]] = None
self.gpu_cache: Optional[List[List[torch.tensor]]] = None
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
......@@ -217,10 +217,15 @@ class Worker(LocalOrDistributedWorkerBase):
def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config,
self.device_config)
self.gpu_cache = self.cache_engine.gpu_cache
self.cache_engine = [
CacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.gpu_cache = [
self.cache_engine[ve].gpu_cache
for ve in range(self.parallel_config.pipeline_parallel_size)
]
def _warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
......@@ -234,12 +239,13 @@ class Worker(LocalOrDistributedWorkerBase):
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[torch.Tensor]]:
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
return self.gpu_cache
@torch.inference_mode()
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync.
......@@ -261,20 +267,24 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
)
@torch.inference_mode()
def execute_worker(self, worker_input: WorkerInput) -> None:
virtual_engine = worker_input.virtual_engine
# Issue cache operations.
if (worker_input.blocks_to_swap_in is not None
and worker_input.blocks_to_swap_in.numel() > 0):
self.cache_engine.swap_in(worker_input.blocks_to_swap_in)
self.cache_engine[virtual_engine].swap_in(
worker_input.blocks_to_swap_in)
if (worker_input.blocks_to_swap_out is not None
and worker_input.blocks_to_swap_out.numel() > 0):
self.cache_engine.swap_out(worker_input.blocks_to_swap_out)
self.cache_engine[virtual_engine].swap_out(
worker_input.blocks_to_swap_out)
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine.copy(worker_input.blocks_to_copy)
self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
......
......@@ -6,10 +6,11 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from vllm.distributed import broadcast_tensor_dict
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
update_environment_variables)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
......@@ -124,6 +125,7 @@ class WorkerInput:
blocks_to_swap_in: Optional[torch.Tensor] = None
blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None
virtual_engine: int = 0
@classmethod
def from_broadcasted_tensor_dict(
......@@ -139,6 +141,7 @@ class WorkerInput:
blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
)
def as_broadcastable_tensor_dict(
......@@ -151,6 +154,7 @@ class WorkerInput:
"blocks_to_swap_in": self.blocks_to_swap_in,
"blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
}
return tensor_dict
......@@ -181,11 +185,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
@property
@abstractmethod
def kv_cache(self) -> Optional[List[torch.Tensor]]:
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
"""
Get the kv cache to pass to the worker's model runner. Used by the
default `execute_model`. If the worker's model runner does not follow
the ModelRunnerBase interface, then inherit from WorkerBase instead.
Gets the list of kv caches to pass to the worker's model runner. Each
element in the list is a kv cache corresponding to a particular virtual
engine (PP stream). Used by the default `execute_model`. If the worker's
model runner does not follow the ModelRunnerBase interface, then inherit
from WorkerBase instead.
"""
raise NotImplementedError
......@@ -227,7 +233,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list))
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine))
num_steps = execute_model_req.num_steps
if self.do_metadata_broadcast:
......@@ -255,8 +262,23 @@ class LocalOrDistributedWorkerBase(WorkerBase):
if worker_input.num_seq_groups == 0:
return []
return self.model_runner.execute_model(model_input, self.kv_cache,
num_steps)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict())
output = self.model_runner.execute_model(
model_input, self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None, intermediate_tensors,
num_steps)
if not get_pp_group().is_last_rank:
get_pp_group().send_tensor_dict(output.tensors)
return [None]
# Worker only supports single-step execution. Wrap the output in a
# list to conform to interface.
return output
class WorkerWrapperBase:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment