Unverified commit c66aa48e, authored by Woosuk Kwon and committed by GitHub
Browse files

[Model Runner V2] Add model states [1/N] (#35350)


Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
parent b6d5a172
...@@ -22,6 +22,7 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -22,6 +22,7 @@ from vllm.v1.worker.gpu.attn_utils import (
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states import ModelState
from vllm.v1.worker.utils import AttentionGroup from vllm.v1.worker.utils import AttentionGroup
...@@ -29,13 +30,11 @@ class CudaGraphManager: ...@@ -29,13 +30,11 @@ class CudaGraphManager:
def __init__( def __init__(
self, self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
uses_mrope: bool,
use_aux_hidden_state_outputs: bool, use_aux_hidden_state_outputs: bool,
device: torch.device, device: torch.device,
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config self.scheduler_config = vllm_config.scheduler_config
self.uses_mrope = uses_mrope
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
self.device = device self.device = device
...@@ -88,8 +87,8 @@ class CudaGraphManager: ...@@ -88,8 +87,8 @@ class CudaGraphManager:
num_tokens: int, num_tokens: int,
capture_cg_mode: CUDAGraphMode, capture_cg_mode: CUDAGraphMode,
model: nn.Module, model: nn.Module,
model_state: ModelState,
input_buffers: InputBuffers, input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
inputs_embeds: torch.Tensor | None, inputs_embeds: torch.Tensor | None,
block_tables: BlockTables, block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
...@@ -113,13 +112,18 @@ class CudaGraphManager: ...@@ -113,13 +112,18 @@ class CudaGraphManager:
) )
else: else:
num_reqs = min(num_tokens, self.max_num_reqs) num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens] model_inputs = {
if self.uses_mrope: "input_ids": input_buffers.input_ids[:num_tokens],
assert mrope_positions is not None "positions": input_buffers.positions[:num_tokens],
positions = mrope_positions[:, :num_tokens] "inputs_embeds": (
if inputs_embeds is not None: inputs_embeds[:num_tokens] if inputs_embeds is not None else None
inputs_embeds = inputs_embeds[:num_tokens] ),
# NOTE: Values returned by `prepare_dummy_inputs` will override the
# default values above.
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
}
attn_metadata, slot_mappings = prepare_inputs_to_capture( attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs, num_reqs,
num_tokens, num_tokens,
...@@ -143,11 +147,7 @@ class CudaGraphManager: ...@@ -143,11 +147,7 @@ class CudaGraphManager:
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings, slot_mapping=slot_mappings,
): ):
model_output = model( model_output = model(**model_inputs)
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output hidden_states, aux_hidden_states = model_output
else: else:
...@@ -164,9 +164,7 @@ class CudaGraphManager: ...@@ -164,9 +164,7 @@ class CudaGraphManager:
num_tokens=num_tokens, num_tokens=num_tokens,
num_reqs=num_reqs, num_reqs=num_reqs,
model=model, model=model,
input_ids=input_ids, model_inputs=model_inputs,
positions=positions,
inputs_embeds=inputs_embeds,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
...@@ -178,9 +176,7 @@ class CudaGraphManager: ...@@ -178,9 +176,7 @@ class CudaGraphManager:
num_tokens: int, num_tokens: int,
num_reqs: int, num_reqs: int,
model: nn.Module, model: nn.Module,
input_ids: torch.Tensor, model_inputs: dict[str, torch.Tensor | None],
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor, num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None, attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None, slot_mappings: dict[str, torch.Tensor] | None,
...@@ -206,11 +202,8 @@ class CudaGraphManager: ...@@ -206,11 +202,8 @@ class CudaGraphManager:
), ),
torch.cuda.graph(graph, self.pool), torch.cuda.graph(graph, self.pool),
): ):
model_output = model( model_output = model(**model_inputs)
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
# Join offloader's copy stream after forward to avoid unjoined # Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream, # stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass. # but wait_prefetch only happens in the next forward pass.
...@@ -235,9 +228,7 @@ class CudaGraphManager: ...@@ -235,9 +228,7 @@ class CudaGraphManager:
num_tokens: int, num_tokens: int,
num_reqs: int, num_reqs: int,
model: nn.Module, model: nn.Module,
input_ids: torch.Tensor, model_inputs: dict[str, torch.Tensor | None],
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor, num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None, attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None, slot_mappings: dict[str, torch.Tensor] | None,
...@@ -256,18 +247,14 @@ class CudaGraphManager: ...@@ -256,18 +247,14 @@ class CudaGraphManager:
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings, slot_mapping=slot_mappings,
): ):
model( model(**model_inputs)
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
@torch.inference_mode() @torch.inference_mode()
def capture( def capture(
self, self,
model: nn.Module, model: nn.Module,
model_state: ModelState,
input_buffers: InputBuffers, input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
inputs_embeds: torch.Tensor | None, inputs_embeds: torch.Tensor | None,
block_tables: BlockTables, block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
...@@ -278,8 +265,8 @@ class CudaGraphManager: ...@@ -278,8 +265,8 @@ class CudaGraphManager:
device=self.device, device=self.device,
capture_fn=self.capture_graph, capture_fn=self.capture_graph,
model=model, model=model,
model_state=model_state,
input_buffers=input_buffers, input_buffers=input_buffers,
mrope_positions=mrope_positions,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
block_tables=block_tables, block_tables=block_tables,
attn_groups=attn_groups, attn_groups=attn_groups,
......
...@@ -65,8 +65,6 @@ class InputBatch: ...@@ -65,8 +65,6 @@ class InputBatch:
input_ids: torch.Tensor input_ids: torch.Tensor
# [num_tokens_after_padding] # [num_tokens_after_padding]
positions: torch.Tensor positions: torch.Tensor
# [3, num_tokens_after_padding]
mrope_positions: torch.Tensor | None
# [num_tokens_after_padding, hidden_size] # [num_tokens_after_padding, hidden_size]
inputs_embeds: torch.Tensor | None inputs_embeds: torch.Tensor | None
...@@ -143,7 +141,6 @@ class InputBatch: ...@@ -143,7 +141,6 @@ class InputBatch:
seq_lens=seq_lens, seq_lens=seq_lens,
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
mrope_positions=None,
inputs_embeds=None, inputs_embeds=None,
attn_metadata=None, # type: ignore attn_metadata=None, # type: ignore
slot_mappings=None, # type: ignore slot_mappings=None, # type: ignore
......
...@@ -77,7 +77,7 @@ from vllm.v1.worker.gpu.kv_connector import ( ...@@ -77,7 +77,7 @@ from vllm.v1.worker.gpu.kv_connector import (
) )
from vllm.v1.worker.gpu.lora_utils import LoraState from vllm.v1.worker.gpu.lora_utils import LoraState
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState from vllm.v1.worker.gpu.model_states import ModelState
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
...@@ -140,14 +140,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -140,14 +140,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=self.dtype, dtype=self.dtype,
device=self.device, device=self.device,
) )
self.uses_mrope = self.model_config.uses_mrope
if self.uses_mrope:
self.mrope_states = MRopeState(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
device=self.device,
)
self.use_async_scheduling = self.scheduler_config.async_scheduling self.use_async_scheduling = self.scheduler_config.async_scheduling
self.output_copy_stream = torch.cuda.Stream(self.device) self.output_copy_stream = torch.cuda.Stream(self.device)
...@@ -212,7 +204,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -212,7 +204,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# CUDA graphs. # CUDA graphs.
self.cudagraph_manager = CudaGraphManager( self.cudagraph_manager = CudaGraphManager(
self.vllm_config, self.vllm_config,
self.uses_mrope,
self.use_aux_hidden_state_outputs, self.use_aux_hidden_state_outputs,
self.device, self.device,
) )
...@@ -271,6 +262,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -271,6 +262,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculator is not None: if self.speculator is not None:
prepare_communication_buffer_for_model(self.speculator) prepare_communication_buffer_for_model(self.speculator)
# Initialize the components that require the model.
self.model_state = ModelState(self.vllm_config, self.model, self.device)
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model return self.model
...@@ -481,16 +475,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -481,16 +475,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
start_free_gpu_memory = torch.cuda.mem_get_info()[0] start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config): with self.maybe_setup_dummy_loras(self.lora_config):
mrope_positions = None
if self.uses_mrope:
mrope_positions = self.mrope_states.mrope_positions
inputs_embeds = None inputs_embeds = None
if self.supports_mm_inputs: if self.supports_mm_inputs:
inputs_embeds = self.encoder_runner.inputs_embeds inputs_embeds = self.encoder_runner.inputs_embeds
self.cudagraph_manager.capture( self.cudagraph_manager.capture(
model=self.model, model=self.model,
model_state=self.model_state,
input_buffers=self.input_buffers, input_buffers=self.input_buffers,
mrope_positions=mrope_positions,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
block_tables=self.block_tables, block_tables=self.block_tables,
attn_groups=self.attn_groups, attn_groups=self.attn_groups,
...@@ -554,14 +545,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -554,14 +545,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.supports_mm_inputs: if self.supports_mm_inputs:
self.encoder_runner.add_request(req_id, new_req_data.mm_features) self.encoder_runner.add_request(req_id, new_req_data.mm_features)
# Pre-compute M-RoPE positions for prefill. self.model_state.add_request(req_index, new_req_data)
if self.uses_mrope:
self.mrope_states.init_prefill_mrope_positions(
req_index,
self.model, # type: ignore
new_req_data.prefill_token_ids,
mm_features=new_req_data.mm_features,
)
self.block_tables.append_block_ids( self.block_tables.append_block_ids(
req_index, new_req_data.block_ids, overwrite=True req_index, new_req_data.block_ids, overwrite=True
...@@ -577,8 +561,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -577,8 +561,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if scheduler_output.scheduled_new_reqs: if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes() self.req_states.apply_staged_writes()
self.sampler.apply_staged_writes() self.sampler.apply_staged_writes()
if self.uses_mrope: self.model_state.apply_staged_writes()
self.mrope_states.apply_staged_writes()
def update_requests(self, scheduler_output: SchedulerOutput) -> None: def update_requests(self, scheduler_output: SchedulerOutput) -> None:
# Add new blocks for the existing requests. # Add new blocks for the existing requests.
...@@ -692,15 +675,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -692,15 +675,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs] dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
# Prepare M-RoPE positions.
if self.uses_mrope:
self.mrope_states.prepare_mrope_positions(
idx_mapping,
query_start_loc,
self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens.gpu,
)
# Some input token ids are directly read from the last sampled tokens # Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from. # and draft tokens. Also, get the logits indices to sample tokens from.
logits_indices = combine_sampled_and_draft_tokens( logits_indices = combine_sampled_and_draft_tokens(
...@@ -744,10 +718,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -744,10 +718,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding] input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
positions = self.input_buffers.positions[:num_tokens_after_padding] positions = self.input_buffers.positions[:num_tokens_after_padding]
mrope_positions = None
if self.uses_mrope:
mrope_positions = self.mrope_states.mrope_positions
mrope_positions = mrope_positions[:, :num_tokens_after_padding]
return InputBatch( return InputBatch(
req_ids=req_ids, req_ids=req_ids,
num_reqs=num_reqs, num_reqs=num_reqs,
...@@ -764,7 +734,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -764,7 +734,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
seq_lens=seq_lens, seq_lens=seq_lens,
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
mrope_positions=mrope_positions,
inputs_embeds=None, inputs_embeds=None,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer, slot_mappings=slot_mappings_by_layer,
...@@ -959,14 +928,24 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -959,14 +928,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_buffers=self.input_buffers, input_buffers=self.input_buffers,
device=self.device, device=self.device,
) )
if self.uses_mrope:
input_batch.mrope_positions = self.mrope_states.mrope_positions[
:, :num_tokens_after_padding
]
if not skip_attn_for_dummy_run: if not skip_attn_for_dummy_run:
self.prepare_dummy_attn_metadata(input_batch) self.prepare_dummy_attn_metadata(input_batch)
# FIXME(woosuk): Fix warmup for LoRA. # FIXME(woosuk): Fix warmup for LoRA.
model_inputs = {
"input_ids": input_batch.input_ids,
"positions": input_batch.positions,
"inputs_embeds": input_batch.inputs_embeds,
# NOTE: Values returned by `prepare_inputs` will override the default
# values above.
**self.model_state.prepare_inputs(input_batch, self.req_states),
}
if not self.is_first_pp_rank:
# Update for non-first PP ranks.
model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None
model_inputs["intermediate_tensors"] = intermediate_tensors
# Run model. # Run model.
if cudagraph_runtime_mode == CUDAGraphMode.FULL: if cudagraph_runtime_mode == CUDAGraphMode.FULL:
# Use explicit cudagraph replay for FULL mode. # Use explicit cudagraph replay for FULL mode.
...@@ -983,20 +962,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -983,20 +962,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states = None aux_hidden_states = None
else: else:
# For piecewise and eager mode, just call model(). # For piecewise and eager mode, just call model().
positions = input_batch.positions
if self.uses_mrope:
assert input_batch.mrope_positions is not None
positions = input_batch.mrope_positions
if self.is_first_pp_rank:
input_ids = input_batch.input_ids
inputs_embeds = input_batch.inputs_embeds
assert intermediate_tensors is None
else:
input_ids = None
inputs_embeds = None
assert intermediate_tensors is not None
batch_descriptor = BatchDescriptor( batch_descriptor = BatchDescriptor(
num_tokens=input_batch.num_tokens_after_padding, num_tokens=input_batch.num_tokens_after_padding,
has_lora=self.lora_config is not None, has_lora=self.lora_config is not None,
...@@ -1012,12 +977,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1012,12 +977,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mapping=input_batch.slot_mappings, slot_mapping=input_batch.slot_mappings,
): ):
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
model_output = self.model( model_output = self.model(**model_inputs)
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
)
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output hidden_states, aux_hidden_states = model_output
else: else:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.states import RequestState
class ModelState:
    """Model-specific state that lives next to the GPU model runner.

    The only model-specific feature handled here is M-RoPE. When
    ``model_config.uses_mrope`` is truthy an ``MRopeState`` is created and the
    methods below forward to it; otherwise ``add_request`` and
    ``apply_staged_writes`` do nothing and the input-preparation methods
    return an empty dict.

    The dicts returned by ``prepare_inputs`` / ``prepare_dummy_inputs`` are
    meant to be merged over the caller's default keyword arguments for the
    model call, so any key present here replaces the default of the same name.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        device: torch.device,
    ):
        """Cache config-derived limits and create the M-RoPE state if needed.

        Args:
            vllm_config: Top-level config; its ``model_config`` and
                ``scheduler_config`` are read here.
            model: The loaded model, later handed to the M-RoPE state when
                prefill positions are initialized.
            device: Device passed through to ``MRopeState``.
        """
        model_cfg = vllm_config.model_config
        sched_cfg = vllm_config.scheduler_config

        self.vllm_config = vllm_config
        self.model_config = model_cfg
        self.scheduler_config = sched_cfg
        self.model = model
        self.device = device

        # Size limits used to construct the per-model state below.
        self.max_model_len = model_cfg.max_model_len
        self.max_num_reqs = sched_cfg.max_num_seqs
        self.max_num_tokens = sched_cfg.max_num_batched_tokens

        self.uses_mrope = model_cfg.uses_mrope
        if not self.uses_mrope:
            # No `mrope_state` attribute is created in the common case.
            return
        self.mrope_state = MRopeState(
            max_num_reqs=self.max_num_reqs,
            max_num_tokens=self.max_num_tokens,
            max_model_len=self.max_model_len,
            device=self.device,
        )

    def _mrope_positions(self, num_tokens: int) -> torch.Tensor:
        """Return the M-RoPE position buffer cut to its first `num_tokens` columns."""
        # The buffer is indexed on two axes; presumably it is laid out as
        # [3, num_tokens] - confirm against MRopeState.
        return self.mrope_state.mrope_positions[:, :num_tokens]

    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
        """Register a new request; for M-RoPE models, pre-compute its prefill positions.

        Args:
            req_index: Index of the request slot being filled.
            new_req_data: Scheduler data for the new request. Its
                ``prefill_token_ids`` must not be ``None`` for M-RoPE models.
        """
        if not self.uses_mrope:
            return
        token_ids = new_req_data.prefill_token_ids
        assert token_ids is not None
        self.mrope_state.init_prefill_mrope_positions(
            req_index,
            self.model,  # type: ignore
            token_ids,
            mm_features=new_req_data.mm_features,
        )

    def apply_staged_writes(self) -> None:
        """Forward ``apply_staged_writes`` to the M-RoPE state, if there is one."""
        if not self.uses_mrope:
            return
        self.mrope_state.apply_staged_writes()

    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, torch.Tensor | None]:
        """Build the model-specific keyword overrides for a real forward pass.

        Args:
            input_batch: Batch whose ``idx_mapping``, ``query_start_loc`` and
                ``num_tokens_after_padding`` are read.
            req_states: Request state whose ``prefill_len.gpu`` and
                ``num_computed_tokens.gpu`` are read.

        Returns:
            ``{}`` for models without M-RoPE (plain 1D positions); otherwise
            ``{"positions": ...}`` holding the M-RoPE positions sliced to
            ``input_batch.num_tokens_after_padding``.
        """
        if not self.uses_mrope:
            # Common case (1D positions).
            return {}

        # The position buffer has to be refreshed before it is sliced.
        self.mrope_state.prepare_mrope_positions(
            input_batch.idx_mapping,
            input_batch.query_start_loc,
            req_states.prefill_len.gpu,
            req_states.num_computed_tokens.gpu,
        )
        padded_len = input_batch.num_tokens_after_padding
        return {"positions": self._mrope_positions(padded_len)}

    def prepare_dummy_inputs(
        self, num_reqs: int, num_tokens: int
    ) -> dict[str, torch.Tensor | None]:
        """Build the model-specific keyword overrides for a dummy forward pass.

        Unlike ``prepare_inputs``, the position buffer is only sliced, not
        recomputed.

        Args:
            num_reqs: Number of dummy requests (not used by this
                implementation).
            num_tokens: Number of dummy tokens; sets the slice length.

        Returns:
            ``{}`` for models without M-RoPE; otherwise ``{"positions": ...}``
            holding the first ``num_tokens`` columns of the M-RoPE buffer.
        """
        if not self.uses_mrope:
            return {}
        return {"positions": self._mrope_positions(num_tokens)}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment