Unverified commit 63f49b8b, authored by zhanqiuhu and committed by GitHub
Browse files

[Model Runner V2] Enable piecewise CUDA graphs for pipeline parallelism (#35162)


Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
parent a5e9d511
...@@ -11,11 +11,16 @@ from tqdm import tqdm ...@@ -11,11 +11,16 @@ from tqdm import tqdm
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank from vllm.distributed.parallel_state import (
get_pp_group,
graph_capture,
is_global_first_rank,
)
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.offloader.base import get_offloader from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
...@@ -87,7 +92,15 @@ class CudaGraphManager: ...@@ -87,7 +92,15 @@ class CudaGraphManager:
assert self.compilation_config is not None assert self.compilation_config is not None
self.cudagraph_mode = cudagraph_mode self.cudagraph_mode = cudagraph_mode
self.decode_query_len = decode_query_len self.decode_query_len = decode_query_len
self.dp_size = vllm_config.parallel_config.data_parallel_size self.dp_size = vllm_config.parallel_config.data_parallel_size
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
if self.pp_size > 1:
self.is_first_pp_rank = get_pp_group().is_first_rank
self.is_last_pp_rank = get_pp_group().is_last_rank
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True
self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {} self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None
...@@ -267,12 +280,14 @@ class ModelCudaGraphManager(CudaGraphManager): ...@@ -267,12 +280,14 @@ class ModelCudaGraphManager(CudaGraphManager):
self.hidden_states: torch.Tensor | None = None self.hidden_states: torch.Tensor | None = None
self.aux_hidden_states: list[torch.Tensor] = [] self.aux_hidden_states: list[torch.Tensor] = []
self.use_aux_hidden_state_outputs = False self.use_aux_hidden_state_outputs = False
self.intermediate_tensors: IntermediateTensors | None = None
def capture( def capture(
self, self,
model: nn.Module, model: nn.Module,
model_state: ModelState, model_state: ModelState,
input_buffers: InputBuffers, input_buffers: InputBuffers,
intermediate_tensors: IntermediateTensors | None,
block_tables: BlockTables, block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
...@@ -293,6 +308,19 @@ class ModelCudaGraphManager(CudaGraphManager): ...@@ -293,6 +308,19 @@ class ModelCudaGraphManager(CudaGraphManager):
if self.dp_size > 1 if self.dp_size > 1
else None else None
) )
model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
}
if not self.is_first_pp_rank:
# Update for non-first PP ranks.
model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None
assert intermediate_tensors is not None
model_inputs["intermediate_tensors"] = intermediate_tensors[:num_tokens]
attn_metadata, slot_mappings = prepare_inputs_to_capture( attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs, num_reqs,
num_tokens, num_tokens,
...@@ -318,21 +346,15 @@ class ModelCudaGraphManager(CudaGraphManager): ...@@ -318,21 +346,15 @@ class ModelCudaGraphManager(CudaGraphManager):
slot_mapping=slot_mappings, slot_mapping=slot_mappings,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
): ):
model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
# TODO: Pass intermediate_tensors for PP CUDA graph
# support (https://github.com/vllm-project/vllm/pull/35162).
"intermediate_tensors": None,
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
}
model_output = model(**model_inputs) model_output = model(**model_inputs)
if cg_mode == CUDAGraphMode.PIECEWISE: if cg_mode == CUDAGraphMode.PIECEWISE:
# PW CUDA graph internally handles the model outputs. # PW CUDA graph internally handles the model outputs.
# No need to keep track of the hidden states. # No need to keep track of the hidden states.
return None return None
if self.is_last_pp_rank:
# Last PP rank (common case).
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output hidden_states, aux_hidden_states = model_output
else: else:
...@@ -340,13 +362,26 @@ class ModelCudaGraphManager(CudaGraphManager): ...@@ -340,13 +362,26 @@ class ModelCudaGraphManager(CudaGraphManager):
aux_hidden_states = [] aux_hidden_states = []
if self.hidden_states is None: if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states) self.hidden_states = torch.empty_like(hidden_states)
self.hidden_states[:num_tokens] = hidden_states
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states: if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
self.aux_hidden_states = [ self.aux_hidden_states = [
torch.empty_like(x) for x in aux_hidden_states torch.empty_like(x) for x in aux_hidden_states
] ]
self.hidden_states[:num_tokens] = hidden_states
for i, aux in enumerate(aux_hidden_states): for i, aux in enumerate(aux_hidden_states):
self.aux_hidden_states[i][:num_tokens] = aux self.aux_hidden_states[i][:num_tokens] = aux
else:
# Non-last PP rank.
intermediate_tensors = model_output
assert isinstance(intermediate_tensors, IntermediateTensors)
if self.intermediate_tensors is None:
self.intermediate_tensors = IntermediateTensors(
{
k: torch.empty_like(v)
for k, v in intermediate_tensors.tensors.items()
}
)
for k, v in intermediate_tensors.tensors.items():
self.intermediate_tensors[k][:num_tokens] = v
return forward_fn return forward_fn
...@@ -354,9 +389,13 @@ class ModelCudaGraphManager(CudaGraphManager): ...@@ -354,9 +389,13 @@ class ModelCudaGraphManager(CudaGraphManager):
def run_fullgraph( def run_fullgraph(
self, desc: BatchExecutionDescriptor self, desc: BatchExecutionDescriptor
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]] | IntermediateTensors:
"""Replay a captured FULL cudagraph and return hidden states.""" """Replay a captured FULL cudagraph and return hidden states."""
super().run_fullgraph(desc) super().run_fullgraph(desc)
if not self.is_last_pp_rank:
assert self.intermediate_tensors is not None
return self.intermediate_tensors[: desc.num_tokens]
assert self.hidden_states is not None assert self.hidden_states is not None
hidden_states = self.hidden_states[: desc.num_tokens] hidden_states = self.hidden_states[: desc.num_tokens]
if not self.use_aux_hidden_state_outputs: if not self.use_aux_hidden_state_outputs:
......
...@@ -140,6 +140,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -140,6 +140,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
self.is_first_pp_rank = True self.is_first_pp_rank = True
self.is_last_pp_rank = True self.is_last_pp_rank = True
# Persistent buffer for intermediate tensors (non-first PP ranks).
self.intermediate_tensors: IntermediateTensors | None = None
# Data parallelism. # Data parallelism.
self.dp_size = self.parallel_config.data_parallel_size self.dp_size = self.parallel_config.data_parallel_size
...@@ -301,6 +303,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -301,6 +303,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.is_pooling_model and self.is_last_pp_rank: if self.is_pooling_model and self.is_last_pp_rank:
self.pooling_runner = PoolingRunner(self.model) self.pooling_runner = PoolingRunner(self.model)
if not self.is_first_pp_rank:
# For non-first PP ranks, create intermediate tensors sized
# for the max capture size so they can be sliced per batch.
# Save as persistent member so runtime can copy received data
# into the same addresses that the CUDA graphs captured.
self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device,
)
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model return self.model
...@@ -396,14 +409,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -396,14 +409,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Disable any use of KVConnector for dummy runs. # Disable any use of KVConnector for dummy runs.
self.kv_connector.set_disabled(True) self.kv_connector.set_disabled(True)
# For non-first PP ranks, create dummy intermediate_tensors. # Get the intermediate tensors for the dummy run.
intermediate_tensors = None intermediate_tensors = None
if not self.is_first_pp_rank: if not self.is_first_pp_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors( assert self.intermediate_tensors is not None
batch_size=num_tokens, intermediate_tensors = self.intermediate_tensors[:num_tokens]
dtype=self.model_config.dtype,
device=self.device,
)
# Execute the model. # Execute the model.
self.execute_model( self.execute_model(
...@@ -528,14 +538,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -528,14 +538,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
return 0 return 0
# TODO (zhanqiu): support CUDA graph for PP.
if self.use_pp:
logger.warning_once(
"Skipping CUDA graph capture because pipeline parallel is "
"enabled. Pipeline parallel is currently eager-only.",
)
return 0
start_time = time.perf_counter() start_time = time.perf_counter()
gc.collect() gc.collect()
torch.accelerator.empty_cache() torch.accelerator.empty_cache()
...@@ -546,6 +548,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -546,6 +548,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.model, self.model,
self.model_state, self.model_state,
self.input_buffers, self.input_buffers,
self.intermediate_tensors,
self.block_tables, self.block_tables,
self.attn_groups, self.attn_groups,
self.kv_cache_config, self.kv_cache_config,
...@@ -1010,7 +1013,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1010,7 +1013,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"input_ids": input_batch.input_ids, "input_ids": input_batch.input_ids,
"positions": input_batch.positions, "positions": input_batch.positions,
"inputs_embeds": inputs_embeds, "inputs_embeds": inputs_embeds,
"intermediate_tensors": intermediate_tensors,
# NOTE: Values returned by `prepare_inputs` will override the default # NOTE: Values returned by `prepare_inputs` will override the default
# values above. # values above.
**self.model_state.prepare_inputs(input_batch, self.req_states), **self.model_state.prepare_inputs(input_batch, self.req_states),
...@@ -1019,7 +1021,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1019,7 +1021,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update for non-first PP ranks. # Update for non-first PP ranks.
model_inputs["input_ids"] = None model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None model_inputs["inputs_embeds"] = None
# Prepare the intermediate tensors.
assert intermediate_tensors is not None assert intermediate_tensors is not None
assert self.intermediate_tensors is not None
n = input_batch.num_tokens_after_padding
intermediate_tensors = IntermediateTensors(
{
k: v[:n].copy_(intermediate_tensors.tensors[k][:n])
for k, v in self.intermediate_tensors.tensors.items()
},
intermediate_tensors.kv_connector_output,
)
model_inputs["intermediate_tensors"] = intermediate_tensors
# Run model. # Run model.
if batch_desc.cg_mode == CUDAGraphMode.FULL: if batch_desc.cg_mode == CUDAGraphMode.FULL:
...@@ -1028,11 +1042,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1028,11 +1042,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# because they are already copied to the CUDA graph input buffers. # because they are already copied to the CUDA graph input buffers.
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
model_output = self.cudagraph_manager.run_fullgraph(batch_desc) model_output = self.cudagraph_manager.run_fullgraph(batch_desc)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
else: else:
# For piecewise and eager mode, just call model(). # For piecewise and eager mode, just call model().
batch_descriptor = BatchDescriptor( batch_descriptor = BatchDescriptor(
...@@ -1052,11 +1061,21 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1052,11 +1061,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
): ):
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
model_output = self.model(**model_inputs) model_output = self.model(**model_inputs)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output if self.is_last_pp_rank:
else: if self.use_aux_hidden_state_outputs:
hidden_states = model_output assert isinstance(model_output, tuple)
aux_hidden_states = None hidden_states, aux_hidden_states = model_output
else:
assert isinstance(model_output, torch.Tensor)
hidden_states = model_output
aux_hidden_states = None
output_intermediate_tensors = None
else:
assert isinstance(model_output, IntermediateTensors)
hidden_states = None
aux_hidden_states = None
output_intermediate_tensors = model_output
kv_connector_output = self.kv_connector.post_forward(scheduler_output) kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = ExecuteModelState( self.execute_model_state = ExecuteModelState(
...@@ -1071,11 +1090,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1071,11 +1090,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not self.is_last_pp_rank: if not self.is_last_pp_rank:
# Non-last PP rank: return IntermediateTensors for sending. # Non-last PP rank: return IntermediateTensors for sending.
assert isinstance(hidden_states, IntermediateTensors) assert output_intermediate_tensors is not None
hidden_states.kv_connector_output = kv_connector_output output_intermediate_tensors.kv_connector_output = kv_connector_output
return hidden_states return output_intermediate_tensors
# Last rank (or no PP): hidden_states is a tensor for sampling.
assert isinstance(hidden_states, torch.Tensor)
return None return None
@torch.inference_mode() @torch.inference_mode()
...@@ -1259,7 +1276,7 @@ class ExecuteModelState(NamedTuple): ...@@ -1259,7 +1276,7 @@ class ExecuteModelState(NamedTuple):
input_batch: InputBatch input_batch: InputBatch
attn_metadata: dict[str, Any] | None attn_metadata: dict[str, Any] | None
slot_mappings_by_layer: dict[str, torch.Tensor] | None slot_mappings_by_layer: dict[str, torch.Tensor] | None
hidden_states: torch.Tensor | IntermediateTensors hidden_states: torch.Tensor | None
aux_hidden_states: list[torch.Tensor] | None aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None kv_connector_output: KVConnectorOutput | None
num_tokens_across_dp: torch.Tensor | None num_tokens_across_dp: torch.Tensor | None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment