Unverified commit 9a8853f7, authored by zhanqiuhu and committed by GitHub
Browse files

[Core] Pipeline Parallel support for Model Runner V2 (#33960)


Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
parent 387a1898
......@@ -3,7 +3,6 @@
import gc
import time
from copy import deepcopy
from typing import Any
import numpy as np
import torch
......@@ -11,11 +10,15 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import prepare_communication_buffer_for_model
from vllm.distributed.parallel_state import (
get_pp_group,
prepare_communication_buffer_for_model,
)
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
......@@ -54,6 +57,7 @@ from vllm.v1.worker.gpu.kv_connector import (
from vllm.v1.worker.gpu.lora_utils import LoraState
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.pp_handler import PPHandler, get_pp_handler
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
from vllm.v1.worker.gpu.sample.sampler import Sampler
......@@ -178,6 +182,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
# Pipeline parallelism.
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.pp_handler: PPHandler | None = (
get_pp_handler(self.parallel_config) if self.use_pp else None
)
def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len
self.req_states.max_model_len = max_model_len
......@@ -290,7 +300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@torch.inference_mode()
def _dummy_run(
self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
# Create a dummy scheduler output.
num_reqs = min(num_tokens, self.max_num_reqs)
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
......@@ -306,13 +316,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Disable any use of KVConnector for dummy runs.
self.kv_connector.set_disabled(True)
# For non-first PP ranks, create dummy intermediate_tensors.
intermediate_tensors = None
if self.use_pp and not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=num_tokens,
dtype=self.model_config.dtype,
device=self.device,
)
# Execute the model.
self.execute_model(
dummy_scheduler_output, dummy_run=True, skip_attn_for_dummy_run=skip_attn
dummy_scheduler_output,
intermediate_tensors=intermediate_tensors,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
)
self.kv_connector.set_disabled(False)
# Non-last PP ranks don't produce output for sampling.
if self.use_pp and not get_pp_group().is_last_rank:
return None, None
assert self.execute_model_state is not None
hidden_states, input_batch, _ = self.execute_model_state
assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
......@@ -345,6 +373,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, sample_hidden_states = self._dummy_run(
self.max_num_tokens, skip_attn=True
)
# Only run sampler on last PP rank (non-last ranks return None).
if not self.use_pp or get_pp_group().is_last_rank:
assert sample_hidden_states is not None
self._dummy_sampler_run(sample_hidden_states)
if self.do_spec_decode:
num_tokens_across_dp = make_num_tokens_across_dp(
......@@ -381,6 +412,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return 0
# TODO (zhanqiu): support CUDA graph for PP.
if self.use_pp:
logger.warning_once(
"Skipping CUDA graph capture because pipeline parallel is "
"enabled. Pipeline parallel is currently eager-only.",
)
return 0
start_time = time.perf_counter()
gc.collect()
torch.cuda.empty_cache()
......@@ -801,11 +840,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def execute_model(
self,
scheduler_output: SchedulerOutput,
intermediate_tensors: Any | None = None,
intermediate_tensors: IntermediateTensors | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
) -> ModelRunnerOutput | None:
assert intermediate_tensors is None
) -> ModelRunnerOutput | IntermediateTensors | None:
if not dummy_run:
# Update the request states.
self.finish_requests(scheduler_output)
......@@ -851,8 +889,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
self._set_active_loras(*lora_inputs)
if self.supports_mm_inputs:
# Execute the multimodal encoder.
# Only first PP rank prepares multimodal embeddings.
if self.supports_mm_inputs and (
not self.use_pp or get_pp_group().is_first_rank
):
mm_embeds, is_mm_embed = self.get_mm_embeddings(
scheduler_output.scheduled_encoder_inputs, input_batch
)
......@@ -894,6 +934,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.uses_mrope:
assert input_batch.mrope_positions is not None
positions = input_batch.mrope_positions
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
......@@ -904,6 +945,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mapping=input_batch.slot_mappings,
):
self.kv_connector.pre_forward(scheduler_output)
if self.use_pp and not get_pp_group().is_first_rank:
# Non-first PP rank: forward with intermediate tensors.
assert intermediate_tensors is not None
hidden_states = self.model(
input_ids=None,
positions=positions,
inputs_embeds=None,
intermediate_tensors=intermediate_tensors,
)
else:
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=positions,
......@@ -911,20 +962,54 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = hidden_states, input_batch, kv_connector_output
if self.use_pp and not get_pp_group().is_last_rank:
# Non-last PP rank: return IntermediateTensors for sending.
assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
self.execute_model_state = (None, input_batch, kv_connector_output)
return hidden_states
assert isinstance(hidden_states, torch.Tensor)
# Last rank (or no PP): hidden_states is a tensor for sampling.
self.execute_model_state = (hidden_states, input_batch, kv_connector_output)
return None
@torch.inference_mode()
def sample_tokens(
self, grammar_output: GrammarOutput | None
) -> AsyncOutput | ModelRunnerOutput:
) -> AsyncOutput | ModelRunnerOutput | None:
assert self.execute_model_state is not None
hidden_states, input_batch, kv_connector_output = self.execute_model_state
self.execute_model_state = None # type: ignore
# Non-last PP rank: hidden_states is None because this rank produced
# IntermediateTensors instead of final hidden states. Receive the
# sampled tokens broadcast by the last rank and update local state.
if self.use_pp and not get_pp_group().is_last_rank:
assert self.pp_handler is not None
received = self.pp_handler.maybe_receive_sampled_tokens(
input_batch.num_reqs,
self.device,
max_sample_len=self.num_speculative_steps + 1,
)
if received is not None:
sampled, num_sampled, num_rejected = received
self.postprocess(input_batch, sampled, num_sampled, num_rejected)
return None
# Last rank: sample tokens
sampler_output, num_sampled, num_rejected = self.sample(
hidden_states, input_batch, grammar_output
)
# Broadcast to non-last PP ranks (handles spec decode multi-token).
if self.use_pp:
assert self.pp_handler is not None
self.pp_handler.maybe_broadcast_sampled_tokens(
sampler_output, num_sampled, num_rejected
)
prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
self.model.compute_logits,
hidden_states,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pipeline Parallelism handler for V2 Model Runner."""
import torch
from vllm.distributed.parallel_state import get_pp_group
from vllm.v1.worker.gpu.sample.output import SamplerOutput
class PPHandler:
    """Keeps sampled tokens in sync across pipeline-parallel (PP) ranks.

    Used by Model Runner V2. The last PP rank samples tokens and broadcasts
    them; every other PP rank receives them. Both directions issue the same
    three collectives in the same order (sampled token ids, then
    ``num_sampled``, then ``num_rejected``), so the two methods below must be
    kept in lockstep.

    Intended to be instantiated only when PP is enabled (pp_size > 1).
    """

    def maybe_broadcast_sampled_tokens(
        self,
        sampler_output: SamplerOutput,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> None:
        """Send the sampling results from the last PP rank to the other ranks.

        Does nothing when the calling rank is not the last PP rank.

        Three tensors are broadcast, in this order:
        ``sampler_output.sampled_token_ids`` ([num_reqs, max_sample_len]),
        ``num_sampled`` ([num_reqs]) and ``num_rejected`` ([num_reqs]), which
        covers both regular decode and speculative decoding.

        Args:
            sampler_output: SamplerOutput from sampling.
            num_sampled: Number of accepted tokens per request.
            num_rejected: Number of rejected tokens per request.
        """
        pp_group = get_pp_group()
        if pp_group.is_last_rank:
            # NOTE: num_sampled/num_rejected are only needed
            # for speculative decoding.
            # NOTE(review): the receiving side allocates int64 / int32 / int32
            # buffers, so these tensors presumably need matching dtypes --
            # confirm against the sampler.
            outgoing = (
                sampler_output.sampled_token_ids,
                num_sampled,
                num_rejected,
            )
            for payload in outgoing:
                torch.distributed.broadcast(
                    payload.contiguous(),
                    src=pp_group.last_rank,
                    group=pp_group.device_group,
                )

    def maybe_receive_sampled_tokens(
        self,
        num_reqs: int,
        device: torch.device,
        max_sample_len: int = 1,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
        """Receive the sampling results broadcast by the last PP rank.

        Args:
            num_reqs: Number of requests in the batch.
            device: Device to create the receive buffers on.
            max_sample_len: Maximum number of tokens sampled per request
                (1 for regular decode, >1 for speculative decoding).

        Returns:
            None when called on the last rank (which samples rather than
            receives). Otherwise a tuple
            ``(sampled_tokens, num_sampled, num_rejected)`` where:
                - sampled_tokens: int64, shape [num_reqs, max_sample_len]
                - num_sampled: int32, shape [num_reqs]
                - num_rejected: int32, shape [num_reqs]
        """
        pp_group = get_pp_group()
        if pp_group.is_last_rank:
            return None

        # (shape, dtype) of each receive buffer, listed in the order the last
        # rank broadcasts them. num_sampled/num_rejected are only needed for
        # speculative decoding.
        buffer_specs = (
            ((num_reqs, max_sample_len), torch.int64),
            ((num_reqs,), torch.int32),
            ((num_reqs,), torch.int32),
        )

        filled = []
        for shape, dtype in buffer_specs:
            buffer = torch.empty(shape, dtype=dtype, device=device)
            torch.distributed.broadcast(
                buffer,
                src=pp_group.last_rank,
                group=pp_group.device_group,
            )
            filled.append(buffer)

        sampled_tokens, num_sampled, num_rejected = filled
        return sampled_tokens, num_sampled, num_rejected
def get_pp_handler(parallel_config) -> PPHandler:
    """Create the PPHandler used when pipeline parallelism is enabled.

    Args:
        parallel_config: Configuration object exposing
            ``pipeline_parallel_size``.

    Returns:
        A new PPHandler instance.

    Raises:
        AssertionError: If ``pipeline_parallel_size`` is not greater than 1,
            i.e. the factory was called with pipeline parallelism disabled.
    """
    pp_size = parallel_config.pipeline_parallel_size
    # Callers are expected to check for PP before asking for a handler.
    assert pp_size > 1, (
        "PPHandler should not be created when pipeline parallelism is disabled."
    )
    handler = PPHandler()
    return handler
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment