Unverified commit 3ed46f37, authored by Santino Ramos, committed by GitHub
Browse files

[Model Runner V2] Add Support for XD-RoPE (#36817)


Signed-off-by: Santino Ramos <elsantinoramos@gmail.com>
parent 84868e47
...@@ -320,6 +320,9 @@ class ModelCudaGraphManager(CudaGraphManager): ...@@ -320,6 +320,9 @@ class ModelCudaGraphManager(CudaGraphManager):
model_inputs = { model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens], "input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens], "positions": input_buffers.positions[:num_tokens],
# TODO: Pass intermediate_tensors for PP CUDA graph
# support (https://github.com/vllm-project/vllm/pull/35162).
"intermediate_tensors": None,
**model_state.prepare_dummy_inputs(num_reqs, num_tokens), **model_state.prepare_dummy_inputs(num_reqs, num_tokens),
} }
model_output = model(**model_inputs) model_output = model(**model_inputs)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast
import torch import torch
import torch.nn as nn
from vllm.model_executor.models.interfaces import SupportsMRoPE from vllm.config import ModelConfig
from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsXDRoPE
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
class MRopeState: class RopeState:
"""Unified state for multi-dimensional RoPE variants (M-RoPE, XD-RoPE).
M-RoPE: 3 dims, uses position delta for decode.
XD-RoPE: 3 or 4 dims, delta is 0 (decode uses orig_pos for all dims).
NOTE: `positions` is implemented with one additional dummy position on
purpose to make it non-contiguous so that it can work with torch compile.
See detailed explanation in
https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
NOTE: When M-RoPE is enabled, position ids are 3D regardless of the
modality of inputs. For text-only inputs, each dimension has identical
position IDs, making M-RoPE functionally equivalent to 1D-RoPE.
See page 5 of https://arxiv.org/abs/2409.12191
"""
def __init__( def __init__(
self, self,
num_dims: int,
has_delta: bool,
max_num_reqs: int, max_num_reqs: int,
max_num_tokens: int, max_num_tokens: int,
max_model_len: int, max_model_len: int,
device: torch.device, device: torch.device,
): ):
self.num_dims = num_dims
self.has_delta = has_delta
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.max_model_len = max_model_len self.max_model_len = max_model_len
...@@ -22,47 +46,51 @@ class MRopeState: ...@@ -22,47 +46,51 @@ class MRopeState:
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# wasting a lot of CPU memory. # wasting a lot of CPU memory.
self.prefill_mrope_positions = StagedWriteTensor( self.prefill_positions = StagedWriteTensor(
(max_num_reqs * 3, max_model_len), (max_num_reqs * num_dims, max_model_len),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
uva_instead_of_gpu=True, uva_instead_of_gpu=True,
) )
self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32) self.positions = torch.zeros(
(num_dims, max_num_tokens + 1), dtype=torch.int64, device=device
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
# with torch compile.
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
# the modality of inputs. For text-only inputs, each dimension has
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = torch.zeros(
(3, max_num_tokens + 1), dtype=torch.int64, device=device
) )
def init_prefill_mrope_positions( # Delta is non-zero for M-RoPE, always 0 for XD-RoPE.
self.prefill_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
def init_prefill_positions(
self, self,
req_idx: int, req_idx: int,
mrope_model: SupportsMRoPE, model: nn.Module,
prefill_token_ids: list[int], prefill_token_ids: list[int],
mm_features: list, mm_features: list,
) -> None: ) -> None:
prefill_mrope_positions, prefill_mrope_delta = ( if self.has_delta:
mrope_model.get_mrope_input_positions(prefill_token_ids, mm_features) mrope_model = cast(SupportsMRoPE, model)
) prefill_positions, delta = mrope_model.get_mrope_input_positions(
for i in range(3): prefill_token_ids, mm_features
pos = prefill_mrope_positions[i].tolist() )
self.prefill_mrope_positions.stage_write(3 * req_idx + i, 0, pos) self.prefill_delta.np[req_idx] = delta
self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta else:
xdrope_model = cast(SupportsXDRoPE, model)
prefill_positions = xdrope_model.get_xdrope_input_positions(
prefill_token_ids, mm_features
)
for i in range(self.num_dims):
pos = prefill_positions[i].tolist()
self.prefill_positions.stage_write(self.num_dims * req_idx + i, 0, pos)
def apply_staged_writes(self) -> None: def apply_staged_writes(self) -> None:
self.prefill_mrope_positions.apply_write() self.prefill_positions.apply_write()
self.prefill_mrope_delta.copy_to_uva() if self.has_delta:
self.prefill_delta.copy_to_uva()
def get_positions(self, num_tokens: int) -> torch.Tensor:
return self.positions[:, :num_tokens]
def prepare_mrope_positions( def prepare_positions(
self, self,
idx_mapping: torch.Tensor, idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor, query_start_loc: torch.Tensor,
...@@ -70,34 +98,68 @@ class MRopeState: ...@@ -70,34 +98,68 @@ class MRopeState:
num_computed_tokens: torch.Tensor, num_computed_tokens: torch.Tensor,
) -> None: ) -> None:
num_reqs = idx_mapping.shape[0] num_reqs = idx_mapping.shape[0]
_prepare_mrope_positions_kernel[(num_reqs,)]( _prepare_rope_positions_kernel[(num_reqs,)](
self.mrope_positions, self.positions,
self.mrope_positions.stride(0), self.positions.stride(0),
self.prefill_mrope_positions.gpu, self.prefill_positions.gpu,
3 * self.max_model_len, self.num_dims * self.max_model_len,
self.max_model_len, self.max_model_len,
self.prefill_mrope_delta.gpu, self.prefill_delta.gpu,
idx_mapping, idx_mapping,
query_start_loc, query_start_loc,
prefill_lens, prefill_lens,
num_computed_tokens, num_computed_tokens,
BLOCK_SIZE=1024, BLOCK_SIZE=1024,
NUM_DIMS=self.num_dims,
) )
def get_rope_state(
    model_config: ModelConfig,
    model: nn.Module,
    max_num_reqs: int,
    max_num_tokens: int,
    max_model_len: int,
    device: torch.device,
) -> RopeState | None:
    """Build the multi-dimensional RoPE state for `model`, if it needs one.

    M-RoPE is checked first and yields a 3-dim state with a decode delta.
    Otherwise a positive `model_config.uses_xdrope_dim` selects XD-RoPE
    with that many dims and no delta. Returns None when neither applies.
    """
    # Pick the variant first; the buffer sizing arguments are shared.
    if model_config.uses_mrope:
        assert isinstance(model, SupportsMRoPE)
        num_dims, has_delta = 3, True
    elif model_config.uses_xdrope_dim > 0:
        assert isinstance(model, SupportsXDRoPE)
        num_dims, has_delta = model_config.uses_xdrope_dim, False
    else:
        return None

    return RopeState(
        num_dims=num_dims,
        has_delta=has_delta,
        max_num_reqs=max_num_reqs,
        max_num_tokens=max_num_tokens,
        max_model_len=max_model_len,
        device=device,
    )
@triton.jit @triton.jit
def _prepare_mrope_positions_kernel( def _prepare_rope_positions_kernel(
mrope_positions_ptr, positions_ptr,
mrope_positions_stride, positions_stride,
prefill_mrope_positions_ptr, prefill_positions_ptr,
prefill_mrope_positions_stride0, prefill_positions_stride0,
prefill_mrope_positions_stride1, prefill_positions_stride1,
prefill_mrope_delta_ptr, prefill_delta_ptr,
idx_mapping_ptr, idx_mapping_ptr,
query_start_loc_ptr, query_start_loc_ptr,
prefill_lens_ptr, prefill_lens_ptr,
num_computed_tokens_ptr, num_computed_tokens_ptr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
NUM_DIMS: tl.constexpr,
): ):
batch_idx = tl.program_id(0) batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx) req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
...@@ -110,27 +172,26 @@ def _prepare_mrope_positions_kernel( ...@@ -110,27 +172,26 @@ def _prepare_mrope_positions_kernel(
query_end = tl.load(query_start_loc_ptr + batch_idx + 1) query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start query_len = query_end - query_start
mrope_delta = tl.load(prefill_mrope_delta_ptr + req_state_idx) delta = tl.load(prefill_delta_ptr + req_state_idx)
for i in range(0, query_len, BLOCK_SIZE): for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE) block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len mask = block < query_len
orig_pos = num_computed + block orig_pos = num_computed + block
for j in tl.static_range(3): for j in tl.static_range(NUM_DIMS):
if is_prefill: if is_prefill:
# Read from pre-computed M-RoPE positions.
pos = tl.load( pos = tl.load(
prefill_mrope_positions_ptr prefill_positions_ptr
+ req_state_idx * prefill_mrope_positions_stride0 + req_state_idx * prefill_positions_stride0
+ j * prefill_mrope_positions_stride1 + j * prefill_positions_stride1
+ orig_pos, + orig_pos,
mask=mask, mask=mask,
) )
else: else:
# Apply M-RoPE delta. pos = orig_pos + delta
pos = orig_pos + mrope_delta
tl.store( tl.store(
mrope_positions_ptr + j * mrope_positions_stride + query_start + block, positions_ptr + j * positions_stride + query_start + block,
pos, pos,
mask=mask, mask=mask,
) )
...@@ -992,6 +992,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -992,6 +992,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"input_ids": input_batch.input_ids, "input_ids": input_batch.input_ids,
"positions": input_batch.positions, "positions": input_batch.positions,
"inputs_embeds": inputs_embeds, "inputs_embeds": inputs_embeds,
"intermediate_tensors": intermediate_tensors,
# NOTE: Values returned by `prepare_inputs` will override the default # NOTE: Values returned by `prepare_inputs` will override the default
# values above. # values above.
**self.model_state.prepare_inputs(input_batch, self.req_states), **self.model_state.prepare_inputs(input_batch, self.req_states),
...@@ -1000,7 +1001,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1000,7 +1001,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update for non-first PP ranks. # Update for non-first PP ranks.
model_inputs["input_ids"] = None model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None model_inputs["inputs_embeds"] = None
model_inputs["intermediate_tensors"] = intermediate_tensors assert intermediate_tensors is not None
# Run model. # Run model.
if batch_desc.cg_mode == CUDAGraphMode.FULL: if batch_desc.cg_mode == CUDAGraphMode.FULL:
......
...@@ -13,7 +13,7 @@ from vllm.v1.worker.gpu.attn_utils import build_attn_metadata ...@@ -13,7 +13,7 @@ from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState from vllm.v1.worker.gpu.mm.rope import get_rope_state
from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup from vllm.v1.worker.utils import AttentionGroup
...@@ -52,29 +52,28 @@ class DefaultModelState(ModelState): ...@@ -52,29 +52,28 @@ class DefaultModelState(ModelState):
device=self.device, device=self.device,
) )
self.uses_mrope = self.model_config.uses_mrope self.rope_state = get_rope_state(
if self.uses_mrope: self.model_config,
self.mrope_state = MRopeState( model,
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens, max_num_tokens=self.max_num_tokens,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
device=self.device, device=self.device,
) )
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
if self.uses_mrope: if self.rope_state is not None:
# Pre-compute M-RoPE positions for prefill.
assert new_req_data.prefill_token_ids is not None assert new_req_data.prefill_token_ids is not None
self.mrope_state.init_prefill_mrope_positions( self.rope_state.init_prefill_positions(
req_index, req_index,
self.model, # type: ignore self.model,
new_req_data.prefill_token_ids, new_req_data.prefill_token_ids,
mm_features=new_req_data.mm_features, mm_features=new_req_data.mm_features,
) )
def apply_staged_writes(self) -> None: def apply_staged_writes(self) -> None:
if self.uses_mrope: if self.rope_state is not None:
self.mrope_state.apply_staged_writes() self.rope_state.apply_staged_writes()
def get_mm_embeddings( def get_mm_embeddings(
self, self,
...@@ -109,31 +108,26 @@ class DefaultModelState(ModelState): ...@@ -109,31 +108,26 @@ class DefaultModelState(ModelState):
def prepare_inputs( def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, Any]: ) -> dict[str, torch.Tensor | None]:
if not self.uses_mrope: if self.rope_state is None:
# Common case (1D positions). return {} # Common case (1D positions).
return {}
# Prepare M-RoPE positions. self.rope_state.prepare_positions(
self.mrope_state.prepare_mrope_positions(
input_batch.idx_mapping, input_batch.idx_mapping,
input_batch.query_start_loc, input_batch.query_start_loc,
req_states.prefill_len.gpu, req_states.prefill_len.gpu,
req_states.num_computed_tokens.gpu, req_states.num_computed_tokens.gpu,
) )
mrope_positions = self.mrope_state.mrope_positions[ positions = self.rope_state.get_positions(input_batch.num_tokens_after_padding)
:, : input_batch.num_tokens_after_padding return {"positions": positions}
]
return {"positions": mrope_positions}
def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]: def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
model_inputs = {} model_inputs = {}
if self.supports_mm_inputs: if self.supports_mm_inputs:
inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens] inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
model_inputs["inputs_embeds"] = inputs_embeds model_inputs["inputs_embeds"] = inputs_embeds
if self.uses_mrope: if self.rope_state is not None:
mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens] model_inputs["positions"] = self.rope_state.get_positions(num_tokens)
model_inputs["positions"] = mrope_positions
return model_inputs return model_inputs
def prepare_attn( def prepare_attn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment