Unverified commit 9a8853f7, authored by zhanqiuhu and committed by GitHub
Browse files

[Core] Pipeline Parallel support for Model Runner V2 (#33960)


Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
parent 387a1898
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import gc import gc
import time import time
from copy import deepcopy from copy import deepcopy
from typing import Any
import numpy as np import numpy as np
import torch import torch
...@@ -11,11 +10,15 @@ import torch.nn as nn ...@@ -11,11 +10,15 @@ import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import prepare_communication_buffer_for_model from vllm.distributed.parallel_state import (
get_pp_group,
prepare_communication_buffer_for_model,
)
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
...@@ -54,6 +57,7 @@ from vllm.v1.worker.gpu.kv_connector import ( ...@@ -54,6 +57,7 @@ from vllm.v1.worker.gpu.kv_connector import (
from vllm.v1.worker.gpu.lora_utils import LoraState from vllm.v1.worker.gpu.lora_utils import LoraState
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.pp_handler import PPHandler, get_pp_handler
from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
from vllm.v1.worker.gpu.sample.sampler import Sampler from vllm.v1.worker.gpu.sample.sampler import Sampler
...@@ -178,6 +182,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -178,6 +182,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV Connector if configured. # KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
# Pipeline parallelism.
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.pp_handler: PPHandler | None = (
get_pp_handler(self.parallel_config) if self.use_pp else None
)
def update_max_model_len(self, max_model_len: int) -> None: def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.req_states.max_model_len = max_model_len self.req_states.max_model_len = max_model_len
...@@ -290,7 +300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -290,7 +300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@torch.inference_mode() @torch.inference_mode()
def _dummy_run( def _dummy_run(
self, num_tokens: int, *args, skip_attn: bool = True, **kwargs self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
# Create a dummy scheduler output. # Create a dummy scheduler output.
num_reqs = min(num_tokens, self.max_num_reqs) num_reqs = min(num_tokens, self.max_num_reqs)
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
...@@ -306,13 +316,31 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -306,13 +316,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Disable any use of KVConnector for dummy runs. # Disable any use of KVConnector for dummy runs.
self.kv_connector.set_disabled(True) self.kv_connector.set_disabled(True)
# For non-first PP ranks, create dummy intermediate_tensors.
intermediate_tensors = None
if self.use_pp and not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=num_tokens,
dtype=self.model_config.dtype,
device=self.device,
)
# Execute the model. # Execute the model.
self.execute_model( self.execute_model(
dummy_scheduler_output, dummy_run=True, skip_attn_for_dummy_run=skip_attn dummy_scheduler_output,
intermediate_tensors=intermediate_tensors,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
) )
self.kv_connector.set_disabled(False) self.kv_connector.set_disabled(False)
# Non-last PP ranks don't produce output for sampling.
if self.use_pp and not get_pp_group().is_last_rank:
return None, None
assert self.execute_model_state is not None assert self.execute_model_state is not None
hidden_states, input_batch, _ = self.execute_model_state hidden_states, input_batch, _ = self.execute_model_state
assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices] sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states return hidden_states, sample_hidden_states
...@@ -345,7 +373,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -345,7 +373,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states, sample_hidden_states = self._dummy_run( hidden_states, sample_hidden_states = self._dummy_run(
self.max_num_tokens, skip_attn=True self.max_num_tokens, skip_attn=True
) )
self._dummy_sampler_run(sample_hidden_states) # Only run sampler on last PP rank (non-last ranks return None).
if not self.use_pp or get_pp_group().is_last_rank:
assert sample_hidden_states is not None
self._dummy_sampler_run(sample_hidden_states)
if self.do_spec_decode: if self.do_spec_decode:
num_tokens_across_dp = make_num_tokens_across_dp( num_tokens_across_dp = make_num_tokens_across_dp(
self.parallel_config.data_parallel_size, self.max_num_tokens self.parallel_config.data_parallel_size, self.max_num_tokens
...@@ -381,6 +412,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -381,6 +412,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
return 0 return 0
# TODO (zhanqiu): support CUDA graph for PP.
if self.use_pp:
logger.warning_once(
"Skipping CUDA graph capture because pipeline parallel is "
"enabled. Pipeline parallel is currently eager-only.",
)
return 0
start_time = time.perf_counter() start_time = time.perf_counter()
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
...@@ -801,11 +840,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -801,11 +840,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def execute_model( def execute_model(
self, self,
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
intermediate_tensors: Any | None = None, intermediate_tensors: IntermediateTensors | None = None,
dummy_run: bool = False, dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False, skip_attn_for_dummy_run: bool = False,
) -> ModelRunnerOutput | None: ) -> ModelRunnerOutput | IntermediateTensors | None:
assert intermediate_tensors is None
if not dummy_run: if not dummy_run:
# Update the request states. # Update the request states.
self.finish_requests(scheduler_output) self.finish_requests(scheduler_output)
...@@ -851,8 +889,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -851,8 +889,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
self._set_active_loras(*lora_inputs) self._set_active_loras(*lora_inputs)
if self.supports_mm_inputs: # Only first PP rank prepares multimodal embeddings.
# Execute the multimodal encoder. if self.supports_mm_inputs and (
not self.use_pp or get_pp_group().is_first_rank
):
mm_embeds, is_mm_embed = self.get_mm_embeddings( mm_embeds, is_mm_embed = self.get_mm_embeddings(
scheduler_output.scheduled_encoder_inputs, input_batch scheduler_output.scheduled_encoder_inputs, input_batch
) )
...@@ -894,6 +934,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -894,6 +934,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.uses_mrope: if self.uses_mrope:
assert input_batch.mrope_positions is not None assert input_batch.mrope_positions is not None
positions = input_batch.mrope_positions positions = input_batch.mrope_positions
with set_forward_context( with set_forward_context(
input_batch.attn_metadata, input_batch.attn_metadata,
self.vllm_config, self.vllm_config,
...@@ -904,27 +945,71 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -904,27 +945,71 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mapping=input_batch.slot_mappings, slot_mapping=input_batch.slot_mappings,
): ):
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
hidden_states = self.model( if self.use_pp and not get_pp_group().is_first_rank:
input_ids=input_batch.input_ids, # Non-first PP rank: forward with intermediate tensors.
positions=positions, assert intermediate_tensors is not None
inputs_embeds=input_batch.inputs_embeds, hidden_states = self.model(
) input_ids=None,
positions=positions,
inputs_embeds=None,
intermediate_tensors=intermediate_tensors,
)
else:
hidden_states = self.model(
input_ids=input_batch.input_ids,
positions=positions,
inputs_embeds=input_batch.inputs_embeds,
)
kv_connector_output = self.kv_connector.post_forward(scheduler_output) kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = hidden_states, input_batch, kv_connector_output
if self.use_pp and not get_pp_group().is_last_rank:
# Non-last PP rank: return IntermediateTensors for sending.
assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
self.execute_model_state = (None, input_batch, kv_connector_output)
return hidden_states
assert isinstance(hidden_states, torch.Tensor)
# Last rank (or no PP): hidden_states is a tensor for sampling.
self.execute_model_state = (hidden_states, input_batch, kv_connector_output)
return None return None
@torch.inference_mode() @torch.inference_mode()
def sample_tokens( def sample_tokens(
self, grammar_output: GrammarOutput | None self, grammar_output: GrammarOutput | None
) -> AsyncOutput | ModelRunnerOutput: ) -> AsyncOutput | ModelRunnerOutput | None:
assert self.execute_model_state is not None assert self.execute_model_state is not None
hidden_states, input_batch, kv_connector_output = self.execute_model_state hidden_states, input_batch, kv_connector_output = self.execute_model_state
self.execute_model_state = None # type: ignore self.execute_model_state = None # type: ignore
# Non-last PP rank: hidden_states is None because this rank produced
# IntermediateTensors instead of final hidden states. Receive the
# sampled tokens broadcast by the last rank and update local state.
if self.use_pp and not get_pp_group().is_last_rank:
assert self.pp_handler is not None
received = self.pp_handler.maybe_receive_sampled_tokens(
input_batch.num_reqs,
self.device,
max_sample_len=self.num_speculative_steps + 1,
)
if received is not None:
sampled, num_sampled, num_rejected = received
self.postprocess(input_batch, sampled, num_sampled, num_rejected)
return None
# Last rank: sample tokens
sampler_output, num_sampled, num_rejected = self.sample( sampler_output, num_sampled, num_rejected = self.sample(
hidden_states, input_batch, grammar_output hidden_states, input_batch, grammar_output
) )
# Broadcast to non-last PP ranks (handles spec decode multi-token).
if self.use_pp:
assert self.pp_handler is not None
self.pp_handler.maybe_broadcast_sampled_tokens(
sampler_output, num_sampled, num_rejected
)
prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs( prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
self.model.compute_logits, self.model.compute_logits,
hidden_states, hidden_states,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pipeline Parallelism handler for V2 Model Runner."""
import torch
from vllm.distributed.parallel_state import get_pp_group
from vllm.v1.worker.gpu.sample.output import SamplerOutput
class PPHandler:
    """Pipeline parallelism handler for Model Runner V2.

    Manages sampled token synchronization between PP ranks: the last rank
    broadcasts what it sampled and every other rank receives it.
    Only instantiated when PP is enabled (pp_size > 1).

    A broadcast needs identically shaped and typed buffers on every rank, so
    the dtypes used on the wire are declared once here and applied by both
    the sending and the receiving method.
    """

    # Wire dtypes shared by the broadcasting and the receiving side.
    _TOKEN_IDS_DTYPE = torch.int64
    _COUNT_DTYPE = torch.int32

    @staticmethod
    def _broadcast_from_last_rank(tensor: torch.Tensor, pp) -> None:
        """Run one broadcast rooted at the last PP rank on the PP device group.

        On the last rank ``tensor`` is the payload; on every other rank it is
        the buffer that gets filled in place.
        """
        torch.distributed.broadcast(
            tensor,
            src=pp.last_rank,
            group=pp.device_group,
        )

    def maybe_broadcast_sampled_tokens(
        self,
        sampler_output: SamplerOutput,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> None:
        """Broadcast sampled tokens from the last PP rank to all other ranks.

        No-ops if this is not the last rank.

        Broadcasts sampled_token_ids [num_reqs, max_sample_len], num_sampled
        [num_reqs], and num_rejected [num_reqs] to support both regular decode
        and speculative decoding.

        Args:
            sampler_output: SamplerOutput from sampling.
            num_sampled: Number of accepted tokens per request.
            num_rejected: Number of rejected tokens per request.
        """
        pp = get_pp_group()
        if not pp.is_last_rank:
            return

        # The order must mirror maybe_receive_sampled_tokens, which posts the
        # matching broadcasts in the same sequence.
        # NOTE: num_sampled/num_rejected are only needed
        # for speculative decoding.
        payloads = (
            (sampler_output.sampled_token_ids, self._TOKEN_IDS_DTYPE),
            (num_sampled, self._COUNT_DTYPE),
            (num_rejected, self._COUNT_DTYPE),
        )
        for tensor, wire_dtype in payloads:
            # Converting is a no-op when the dtype already matches; otherwise
            # it keeps the payload compatible with the receivers' buffers.
            self._broadcast_from_last_rank(tensor.to(wire_dtype).contiguous(), pp)

    def maybe_receive_sampled_tokens(
        self,
        num_reqs: int,
        device: torch.device,
        max_sample_len: int = 1,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
        """Receive sampled tokens broadcast by the last PP rank.

        Returns None if this is the last rank (which samples, not receives).

        Args:
            num_reqs: Number of requests in the batch.
            device: Device to create tensors on.
            max_sample_len: Maximum number of tokens sampled per request
                (1 for regular decode, >1 for speculative decoding).

        Returns:
            None if called on last rank.
            Otherwise, tuple of (sampled_tokens, num_sampled, num_rejected):
            - sampled_tokens: shape [num_reqs, max_sample_len]
            - num_sampled: shape [num_reqs]
            - num_rejected: shape [num_reqs]
        """
        pp = get_pp_group()
        if pp.is_last_rank:
            return None

        sampled_tokens = torch.empty(
            num_reqs, max_sample_len, dtype=self._TOKEN_IDS_DTYPE, device=device
        )
        # NOTE: num_sampled/num_rejected are only needed
        # for speculative decoding.
        num_sampled = torch.empty(num_reqs, dtype=self._COUNT_DTYPE, device=device)
        num_rejected = torch.empty(num_reqs, dtype=self._COUNT_DTYPE, device=device)

        # Fill the buffers in the same order the last rank sends them.
        for buffer in (sampled_tokens, num_sampled, num_rejected):
            self._broadcast_from_last_rank(buffer, pp)
        return sampled_tokens, num_sampled, num_rejected
def get_pp_handler(parallel_config) -> PPHandler:
    """Build the PPHandler used when pipeline parallelism is on.

    Callers must check that PP is enabled first: a
    ``pipeline_parallel_size`` that is not greater than 1 trips the
    assertion below.
    """
    pp_size = parallel_config.pipeline_parallel_size
    pp_enabled = pp_size > 1
    assert pp_enabled, (
        "PPHandler should not be created when pipeline parallelism is disabled."
    )
    handler = PPHandler()
    return handler
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment