Unverified commit c3b6559a, authored by iefgnoix, committed by GitHub
Browse files

[V1][TPU] Integrate the new ragged paged attention kernel with vLLM v1 on TPU (#13379)


Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent 4be4b26c
......@@ -17,9 +17,8 @@ ray[default]
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch==2.7.0.dev20250226+cpu
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
......@@ -4,13 +4,16 @@ from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops.
# Required to register custom ops.
import torch_xla.experimental.custom_kernel # noqa: F401
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
NUM_QUERIES_PER_BLOCK = 16
NUM_KV_PAGES_PER_BLOCK = 128
class PallasAttentionBackend(AttentionBackend):
......@@ -47,47 +50,23 @@ class PallasAttentionBackend(AttentionBackend):
) -> None:
raise RuntimeError("swap_blocks is not used for the TPU backend.")
@torch.compile(backend="openxla")
@staticmethod
def copy_blocks(
    kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    src_to_dists: Tuple[torch.Tensor, torch.Tensor],
) -> None:
    """Copy KV-cache blocks from source indices to destination indices.

    Args:
        kv_caches: (key_cache, value_cache) pairs. Each cache is indexed
            along dim 1 by block index.
        src_to_dists: A ``(src_indices, dst_indices)`` pair of index
            tensors; block ``src_indices[i]`` is copied onto block
            ``dst_indices[i]`` in every cache.

    The caches are modified in place; nothing is returned.
    """
    src_indices, dst_indices = src_to_dists
    for k_cache, v_cache in kv_caches:
        # Same treatment for K then V, in the original order: mark the
        # buffer as a donor first (presumably so XLA may reuse it for the
        # in-place update -- confirm against torch_xla), then copy blocks.
        for cache in (k_cache, v_cache):
            torch.ops.xla.dynamo_set_buffer_donor_(cache, True)
            cache[:, dst_indices] = cache[:, src_indices]
@dataclass
class PallasMetadata(AttentionMetadata):
# Currently, input sequences can only contain all prefills
# or all decoding.
block_tables: Optional[torch.Tensor] = None
context_lens: Optional[torch.Tensor] = None
effective_query_lens: Optional[torch.Tensor] = None
@property
def prefill_metadata(self) -> Optional["PallasMetadata"]:
if self.num_prefills == 0:
return None
assert self.num_decode_tokens == 0
return self
@property
def decode_metadata(self) -> Optional["PallasMetadata"]:
if self.num_decode_tokens == 0:
return None
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.block_tables is not None
assert self.context_lens is not None
return self
class PallasMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Used in the PallasAttentionBackendImpl
slot_mapping: torch.Tensor
block_tables: torch.Tensor
context_lens: torch.Tensor
query_start_loc: torch.Tensor
num_seqs: int
class PallasAttentionBackendImpl(AttentionImpl):
......@@ -105,10 +84,13 @@ class PallasAttentionBackendImpl(AttentionImpl):
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError("Paged attention Pallas kernel does "
"not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.num_kv_heads = num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
......@@ -126,25 +108,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise NotImplementedError(
"Attention logits soft-capping is not supported.")
if torch_xla.tpu.version() < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
self.megacore_mode = None
tpu_env = torch_xla.tpu.get_tpu_env()
tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
or tpu_env.get("TYPE", None)
or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
assert tpu_type is not None
tpu_type = tpu_type.lower()
if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head"
else:
# NOTE(woosuk): If the batch size is not a multiple of 2, the
# megacore mode will be None.
self.megacore_mode = "batch"
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
......@@ -164,135 +127,47 @@ class PallasAttentionBackendImpl(AttentionImpl):
"""Forward pass with Pallas attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
with shape [0] for profiling run.
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = ([num_kv_heads, num_blocks, block_size, head_size],
[num_kv_heads, num_blocks, block_size, head_size])
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
if attn_metadata is None:
# For determine_available_memory case.
if kv_cache[0].numel() == 0:
if output is None:
output = torch.ones_like(query)
return output
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
value = value.view(batch_size, seq_len, self.num_kv_heads,
self.head_size)
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size)
key = key.view(num_tokens, self.num_kv_heads, self.head_size)
value = value.view(num_tokens, self.num_kv_heads, self.head_size)
key_cache, value_cache = kv_cache
if kv_cache[0].numel() > 0:
slot_mapping = attn_metadata.slot_mapping
key_cache, value_cache = kv_cache
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
query = query * self.scale
if attn_metadata.num_prefills > 0:
if attn_metadata.block_tables is None:
# Prefill without paged KV cache.
assert seq_len % 16 == 0, (
"Pallas FlashAttention kernel requires seq_len to be a "
f"multiple of 16 but got {seq_len}")
# Handle GQA/MQA.
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv,
dim=-2)
key = key.view(batch_size, seq_len, self.num_heads,
self.head_size)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=-2)
value = value.view(batch_size, seq_len, self.num_heads,
self.head_size)
# FlashAttention kernel requires the input shape to be
# [batch_size, num_heads, seq_len, d_model]
# while the input is [batch_size, seq_len, num_heads, d_model].
# Permute the input to match the required format.
output = torch.ops.xla.flash_attention(
query.permute(0, 2, 1, 3),
key.permute(0, 2, 1, 3),
value.permute(0, 2, 1, 3),
True,
)
output = output.permute(0, 2, 1, 3)
else:
# Prefill with paged KV cache.
# TODO(woosuk): Tune the below knobs.
num_kv_pages_per_compute_block = 16
num_queries_per_compute_block = 16
assert seq_len % num_queries_per_compute_block == 0
output = torch.ops.xla.multi_queries_paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.effective_query_lens,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel=True,
)
else:
# Decoding run.
assert kv_cache[0].numel() > 0
query = query.squeeze(dim=1)
pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
assert attn_metadata.block_tables is not None
assert attn_metadata.context_lens is not None
# NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
# block table in SMEM. Therefore, if the block table is too large,
# the kernel compilation will fail. To avoid this, we split the
# batch dimension into smaller chunks and run the kernel multiple
# times.
MAX_SMEM_USAGE = 512 * 1024
size_per_seq = 4 * attn_metadata.block_tables.shape[1]
max_num_seq = MAX_SMEM_USAGE // size_per_seq
if batch_size <= max_num_seq:
output = paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
pages_per_compute_block,
self.megacore_mode,
)
else:
chunk_size = max_num_seq
# Make sure the chunk size is a multiple of 2.
chunk_size = chunk_size // 2 * 2
num_chunks = (batch_size + chunk_size - 1) // chunk_size
output = torch.empty_like(query)
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * chunk_size
chunk_end = chunk_start + chunk_size
# NOTE(woosuk): We skip this line because it causes Dynamo
# compilation error. Instead, we rely on the slice operation
# to handle the out-of-bound case.
# chunk_end = min(chunk_end, batch_size)
chunk_output = paged_attention(
query[chunk_start:chunk_end],
key_cache,
value_cache,
attn_metadata.context_lens[chunk_start:chunk_end],
attn_metadata.block_tables[chunk_start:chunk_end],
pages_per_compute_block,
self.megacore_mode,
)
output[chunk_start:chunk_end] = chunk_output
output = torch.ops.xla.ragged_paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.query_start_loc,
attn_metadata.num_seqs,
num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
num_queries_per_block=NUM_QUERIES_PER_BLOCK,
use_kernel=False,
)
# Reshape the output tensor.
return output.reshape(batch_size, seq_len, hidden_size)
return output.reshape(num_tokens, hidden_size)
def write_to_kv_cache(
......@@ -302,52 +177,21 @@ def write_to_kv_cache(
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
""" Write the key and values to the KV cache.
Args:
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
k_cache = [num_kv_heads, num_blocks, block_size, head_size]
v_cache = [num_kv_heads, num_blocks, block_size, head_size]
"""
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
key = key.flatten(0, 2)
value = value.flatten(0, 2)
key = key.flatten(0, 1)
value = value.flatten(0, 1)
key_cache = key_cache.flatten(0, 2)
value_cache = value_cache.flatten(0, 2)
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)
def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
) -> torch.Tensor:
    """Dispatch to the XLA paged-attention op.

    Args:
        query: Query tensor; ``query.shape[0]`` is taken as the batch size.
        key_cache: Paged key cache, passed through to the op.
        value_cache: Paged value cache, passed through to the op.
        context_lens: Per-sequence context lengths, passed through.
        block_tables: Per-sequence block tables, passed through.
        pages_per_compute_block: Kernel tuning knob, passed through.
        megacore_mode: Requested megacore mode, or ``None``. ``"batch"``
            is dropped to ``None`` when the batch size is odd.

    Returns:
        The tensor produced by ``torch.ops.xla.paged_attention``.
    """
    batch_size = query.shape[0]
    if megacore_mode == "batch" and batch_size % 2 != 0:
        megacore_mode = None

    # NOTE(woosuk): A temporary workaround to avoid the error:
    # "xla::paged_attention() Expected a value of type 'str' for
    # argument 'megacore_mode' but instead found type 'NoneType'."
    # The keyword is therefore omitted entirely when the mode is None.
    mode_kwargs = {}
    if megacore_mode is not None:
        mode_kwargs["megacore_mode"] = megacore_mode

    return torch.ops.xla.paged_attention(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block,
        **mode_kwargs,
    )
......@@ -79,4 +79,4 @@ class ModelRunnerOutput:
# [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs]
# [prompt_len]
prompt_logprobs_dict: Dict[str, LogprobsTensors]
prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]]
......@@ -1071,12 +1071,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self,
hidden_states: torch.Tensor,
scheduler_output: "SchedulerOutput",
) -> Dict[str, LogprobsTensors]:
) -> Dict[str, Optional[LogprobsTensors]]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}
prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}
prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {}
# Since prompt logprobs are a rare feature, prioritize simple,
# maintainable loop over optimal performance.
......
# SPDX-License-Identifier: Apache-2.0
import enum
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from unittest.mock import patch
import numpy as np
......@@ -21,7 +19,9 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingType
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
NUM_QUERIES_PER_BLOCK,
PallasAttentionBackend,
PallasMetadata)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
......@@ -37,36 +37,7 @@ logger = init_logger(__name__)
# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
class ExecutionMode(enum.Enum):
    """Kind of forward pass being executed."""

    PREFILL = enum.auto()
    DECODE = enum.auto()
    PREFIX_PREFILL = enum.auto()

    def is_prefill(self) -> bool:
        """Return True for PREFILL and PREFIX_PREFILL, False for DECODE."""
        return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
@dataclass
class PromptDecodeInfo:
    """Split of the current batch into prompt and decode requests."""

    # IDs of requests whose computed-token count is still below their
    # prompt length.
    prompt_req_ids: List[str]
    # IDs of the remaining requests, each scheduled for exactly one token.
    decode_req_ids: List[str]
    # Number of scheduled tokens per prompt request, parallel to
    # `prompt_req_ids`.
    prompt_scheduled_tokens: List[int]
@dataclass
class PromptData:
    """Prepared inputs for one prompt forward pass."""

    # Token IDs for the prompt chunk.
    input_tokens: torch.Tensor
    # Position of each token within its sequence.
    input_positions: torch.Tensor
    # Attention metadata that accompanies the two tensors above.
    attn_metadata: PallasMetadata
@dataclass
class DecodeData:
    """Prepared inputs for one decode forward pass.

    Every field defaults to ``None``, so an empty instance can be built
    before the decode inputs are known.
    """

    # Token IDs for the decode batch.
    input_tokens: Optional[torch.Tensor] = None
    # Position of each token within its sequence.
    input_positions: Optional[torch.Tensor] = None
    # Attention metadata that accompanies the two tensors above.
    attn_metadata: Optional[PallasMetadata] = None
INVALID_TOKEN_ID = -1
class TPUModelRunner:
......@@ -113,8 +84,6 @@ class TPUModelRunner:
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
self.model: Optional[nn.Module] = None
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
......@@ -134,50 +103,48 @@ class TPUModelRunner:
# KV caches for forward pass
self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
# Cached torch/numpy tensors
self.num_swaps = 2
self.cur_swap_id = 0
self.input_ids_cpu = []
self.input_ids_np = []
self.input_positions_cpu = []
self.input_positions_np = []
self.slot_mapping_cpu = []
self.slot_mapping_np = []
self.prompt_context_lens_cpu = []
self.prompt_effective_query_lens_cpu = []
self.decode_context_lens_cpu = []
self.decode_context_lens_np = []
for _ in range(self.num_swaps):
self.input_ids_cpu.append(
torch.empty(self.max_num_tokens,
dtype=torch.int32,
device="cpu"))
self.input_ids_np.append(self.input_ids_cpu[-1].numpy())
self.input_positions_cpu.append(
torch.empty(self.max_num_tokens,
dtype=torch.int32,
device="cpu"))
self.input_positions_np.append(
self.input_positions_cpu[-1].numpy())
self.slot_mapping_cpu.append(
torch.empty(self.max_num_tokens,
dtype=torch.int64,
device="cpu"))
self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy())
self.prompt_context_lens_cpu.append(
torch.empty((1), dtype=torch.int32, device="cpu"))
self.prompt_effective_query_lens_cpu.append(
torch.empty((1), dtype=torch.int32, device="cpu"))
self.decode_context_lens_cpu.append(
torch.empty(self.max_num_tokens,
dtype=torch.int32,
device="cpu"))
self.decode_context_lens_np.append(
self.decode_context_lens_cpu[-1].numpy())
# Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer.
# Sometimes the numpy op is faster so we create both.
self.input_ids_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu")
self.input_ids_np = self.input_ids_cpu.numpy()
self.positions_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu")
self.positions_np = self.positions_cpu.numpy()
self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu")
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
# self.input_batch.block_table has a shape of [max_num_reqs,
# max_num_blocks_per_req]. To reduce the number of recompilations,
# we want block_table.shape[0] to be num_tokens.
# To make the block_table compatible with the paged attention
# kernel, we want block_table.shape[1] to be a multiple of
# NUM_KV_PAGES_PER_BLOCK.
padded_max_num_blocks_per_req = _get_padded_number(
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
self.block_table_cpu = torch.zeros(
(self.max_num_tokens, padded_max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
device="cpu")
self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
self.seq_lens_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
......@@ -191,7 +158,7 @@ class TPUModelRunner:
the input GPU tensors for the model.
Returns:
True if there is a new/resumed/paused/finished request in the batch.
True if there is a new/resumed/paused/finished request.
If False, we can skip copying SamplingMetadata to the GPU.
"""
# Remove finished requests from the cached states.
......@@ -303,9 +270,6 @@ class TPUModelRunner:
self.input_batch.condense(removed_req_indices)
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
def swap_step(self):
    """Advance `cur_swap_id` by one, wrapping to 0 after `num_swaps - 1`."""
    self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps
def get_model(self) -> nn.Module:
    """Return `self.model`, asserting that it has been set (is not None)."""
    model = self.model
    assert model is not None
    return model
......@@ -345,238 +309,124 @@ class TPUModelRunner:
return kv_cache_spec
def _get_prompts_and_decodes(
self,
scheduler_output: "SchedulerOutput",
) -> PromptDecodeInfo:
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
# Traverse decodes first
decode_req_ids = []
for i in range(num_reqs):
req_id = self.input_batch.req_ids[i]
assert req_id is not None
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
if num_computed_tokens < num_prompt_tokens:
# This is prompt
break
# This is decode
assert num_scheduled_tokens == 1
decode_req_ids.append(req_id)
# Traverse prompts
prompt_req_ids = []
prompt_scheduled_tokens = []
for i in range(len(decode_req_ids), num_reqs):
req_id = self.input_batch.req_ids[i]
# Get the number of scheduled tokens for each request.
num_scheduled_tokens_per_req = []
max_num_scheduled_tokens_all_reqs = 0
for req_id in self.input_batch.req_ids[:num_reqs]:
assert req_id is not None
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
# Must be prompt
assert num_computed_tokens < num_prompt_tokens
prompt_req_ids.append(req_id)
prompt_scheduled_tokens.append(num_scheduled_tokens)
return PromptDecodeInfo(prompt_req_ids, decode_req_ids,
prompt_scheduled_tokens)
# Build the padded input tensors and PallasMetadata for a single prompt
# request (row `req_index` of the persistent batch) that has
# `num_scheduled_tokens` tokens scheduled in this step, and return them
# wrapped in a PromptData.
# NOTE(review): indentation is missing in this view, so the extent of the two
# `if num_computed_tokens > 0:` bodies below is not visible -- confirm
# against the real file before relying on it.
def _prepare_prompt(self, req_index: int,
num_scheduled_tokens: int) -> PromptData:
num_computed_tokens = self.input_batch.num_computed_tokens_cpu[
req_index]
num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index]
# Must be prompt: some prompt tokens are still uncomputed.
assert num_computed_tokens < num_prompt_tokens
# Prompt len: the scheduled token count, plus a padded variant.
# (_get_padded_prompt_len is defined elsewhere; presumably it rounds up to
# a fixed padded size -- TODO confirm.)
prompt_len = num_scheduled_tokens
padded_prompt_len = _get_padded_prompt_len(prompt_len)
assert padded_prompt_len <= self.max_model_len
# Seq len: already-computed tokens plus the (padded) new tokens.
seq_len = num_computed_tokens + prompt_len
padded_seq_len = num_computed_tokens + padded_prompt_len
# Input tokens: slice this request's row of the batch token buffer, then
# zero everything past the real prompt tokens. Since this is a slice of
# token_ids_cpu_tensor, the zero-fill presumably writes into that buffer.
input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
req_index, num_computed_tokens:padded_seq_len]
input_tokens_cpu[prompt_len:] = 0
# Input positions: num_computed_tokens + [0 .. padded_prompt_len - 1],
# written into the cached buffer for the current swap slot; the padded
# tail is zeroed.
input_positions_np = self.input_positions_np[
self.cur_swap_id][:padded_prompt_len]
np.add(num_computed_tokens,
self.arange_np[:padded_prompt_len],
out=input_positions_np)
input_positions_np[prompt_len:] = 0
# Slot mapping: slot = block_number * block_size + offset_in_block, where
# the block number is looked up in this request's block-table row. Padded
# entries are set to _PAD_SLOT_ID.
block_table_np = \
self.input_batch.block_table.get_numpy_array()
block_numbers_np = block_table_np[req_index, input_positions_np //
self.block_size]
block_offsets_np = input_positions_np % self.block_size
slot_mapping_np = self.slot_mapping_np[
self.cur_swap_id][:padded_prompt_len]
np.add(block_numbers_np * self.block_size,
block_offsets_np,
out=slot_mapping_np)
slot_mapping_np[prompt_len:] = _PAD_SLOT_ID
# Block table: stays None unless some tokens were already computed, in
# which case this request's block-table row is used.
block_table_cpu = None
if num_computed_tokens > 0:
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
block_table_cpu = block_table_cpu[req_index]
# Context len: 0 when nothing was computed yet, otherwise seq_len.
self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0
if num_computed_tokens > 0:
self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len
# Effective query len: the unpadded number of scheduled tokens.
self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len
# Get final tensors: reshape to a single row (1, -1) and move each one to
# self.device.
input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
input_positions = self.input_positions_cpu[
self.cur_swap_id][:padded_prompt_len].reshape(1,
-1).to(self.device)
slot_mapping = self.slot_mapping_cpu[
self.cur_swap_id][:padded_prompt_len].reshape(1,
-1).to(self.device)
block_table = block_table_cpu.reshape(1, -1).to(
self.device) if block_table_cpu is not None else None
context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to(
self.device)
effective_query_lens = self.prompt_effective_query_lens_cpu[
self.cur_swap_id].to(self.device)
# Move on to the next swap slot so the next call fills a different set of
# cached CPU buffers.
self.swap_step()
# Attn metadata: one prefill and no decode tokens.
attn_metadata = PallasMetadata(
num_prefills=1,
num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_table,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
)
return PromptData(input_tokens, input_positions, attn_metadata)
def _prepare_decode(
self,
decode_req_ids: List[str],
) -> DecodeData:
# Batch size
batch_size = len(decode_req_ids)
padded_batch_size = _get_padded_batch_size(batch_size)
assert padded_batch_size <= self.max_model_len
# Init [0 .. batch_size - 1]
req_indices_np = self.arange_np[:padded_batch_size]
# Input positions
input_positions_np = self.input_positions_np[
self.cur_swap_id][:padded_batch_size]
np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
0,
out=input_positions_np)
input_positions_np[batch_size:] = 0
input_positions_cpu = self.input_positions_cpu[
self.cur_swap_id][:padded_batch_size]
# Input tokens
token_indices_np = (
input_positions_np +
req_indices_np * self.input_batch.token_ids_cpu.shape[1])
input_tokens_cpu = self.input_ids_cpu[
self.cur_swap_id][:padded_batch_size]
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens_per_req.append(num_tokens)
max_num_scheduled_tokens_all_reqs = max(
max_num_scheduled_tokens_all_reqs, num_tokens)
num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
dtype=np.int32)
assert max_num_scheduled_tokens_all_reqs > 0
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
# For each scheduled token, what is the corresponding req index.
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens_per_req)
# Get batched arange.
# E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# For each scheduled token, what is its position in the corresponding req.
arange = np.concatenate(
[self.arange_np[:n] for n in num_scheduled_tokens_per_req])
# Get positions.
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices_np),
out=input_tokens_cpu)
input_tokens_cpu[batch_size:] = 0
# Slot mapping
block_table_indices_np = (
req_indices_np * self.max_num_blocks_per_req +
input_positions_np // self.block_size)
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
# Calculate the slot mapping.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size` here
# because M (max_model_len) is not necessarily divisible by block_size.
# req_indices: # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size)
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping_np[:total_num_scheduled_tokens])
# Prepare the attention metadata.
self.query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens_per_req,
out=self.query_start_loc_np[1:num_reqs + 1])
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens_per_req)
# Do the padding and copy the tensors to the TPU.
padded_total_num_scheduled_tokens = _get_padded_number(
total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK)
self.input_ids = self.input_ids_cpu[:
padded_total_num_scheduled_tokens].to(
self.device)
self.position_ids = self.positions_cpu[:
padded_total_num_scheduled_tokens].to(
self.device)
self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID
slot_mapping = self.slot_mapping_cpu[:
padded_total_num_scheduled_tokens].to(
self.device)
padded_block_table = self.block_table_cpu[:
padded_total_num_scheduled_tokens]
padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = (
self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
padded_block_table = padded_block_table.to(self.device)
query_start_loc = self.query_start_loc_cpu[:
padded_total_num_scheduled_tokens
+ 1].to(self.device)
seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(
self.device)
block_numbers_np = block_table_cpu.flatten(
)[block_table_indices_np].numpy()
block_offsets_np = input_positions_np % self.block_size
slot_mapping_np = self.slot_mapping_np[
self.cur_swap_id][:padded_batch_size]
np.add(block_numbers_np * self.block_size,
block_offsets_np,
out=slot_mapping_np)
slot_mapping_np[batch_size:] = _PAD_SLOT_ID
block_table_cpu = block_table_cpu[:padded_batch_size]
# Context lens
context_lens_np = self.decode_context_lens_np[
self.cur_swap_id][:padded_batch_size]
np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
1,
out=context_lens_np)
context_lens_np[batch_size:] = 0
# Get final tensors
input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device)
input_positions = input_positions_cpu.reshape(-1, 1).to(self.device)
slot_mapping = self.slot_mapping_cpu[
self.cur_swap_id][:padded_batch_size].reshape(-1,
1).to(self.device)
block_table = block_table_cpu.to(self.device)
context_lens = self.decode_context_lens_cpu[
self.cur_swap_id][:padded_batch_size].to(self.device)
self.swap_step()
# Attn metadata
attn_metadata = PallasMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=padded_batch_size,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_table,
context_lens=context_lens,
effective_query_lens=None,
block_tables=padded_block_table,
context_lens=seq_lens,
query_start_loc=query_start_loc,
num_seqs=num_reqs,
)
return DecodeData(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
# partial request, we do so for simplicity. We will ignore the sampled
# token from the partial request.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices
@torch.no_grad()
def execute_model(
......@@ -586,118 +436,81 @@ class TPUModelRunner:
# Update cached state
self._update_states(scheduler_output)
# If necessary, swap decodes/prompts to have all decodes on the start
ensure_decodes_first(self.input_batch)
# Prepare prompts/decodes info
pd_info = self._get_prompts_and_decodes(scheduler_output)
# Init
num_prompts = len(pd_info.prompt_req_ids)
num_decodes = len(pd_info.decode_req_ids)
decode_data = None
sampled_token_ids = [0] * self.input_batch.num_reqs
# Run each prompt individually
is_first = True
for i in range(num_prompts):
req_id = pd_info.prompt_req_ids[i]
req_index = num_decodes + i
assert req_index == self.input_batch.req_id_to_index[
req_id] # TODO: Remove
req_state = self.requests[req_id]
num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
prompt_len = num_scheduled_tokens
seq_len = req_state.num_computed_tokens + num_scheduled_tokens
# Prepare first prompt
if is_first:
prompt_data = self._prepare_prompt(req_index,
num_scheduled_tokens)
is_first = False
# Run forward pass
with set_forward_context(prompt_data.attn_metadata,
self.vllm_config):
assert self.model is not None
selected_token_ids = self.model(prompt_data.input_tokens,
prompt_data.input_positions,
self.kv_caches)
# In parallel to TPU execution, prepare the next iteration
if i < num_prompts - 1:
# There is next prompt => prepare it
prompt_data = self._prepare_prompt(
req_index + 1, pd_info.prompt_scheduled_tokens[i + 1])
elif i == num_prompts - 1 and num_decodes > 0:
# There is next decode => prepare it
decode_data = self._prepare_decode(pd_info.decode_req_ids)
# Update cached state (if prompt is fully done)
if seq_len >= len(req_state.prompt_token_ids):
# Transfer sampled tokens from TPU to CPU
selected_token_ids_cpu = selected_token_ids.cpu()
# Get output token
token_id = selected_token_ids_cpu[prompt_len - 1].item()
sampled_token_ids[req_index] = token_id
# Add output token to the request
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
self.input_batch.num_tokens[req_index] += 1
req_state.output_token_ids.append(token_id)
# Prepare inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# Run decodes (a single batch)
if num_decodes > 0:
# Prepare decode (if was not yet prepared)
if decode_data is None:
decode_data = self._prepare_decode(pd_info.decode_req_ids)
# Run forward pass
with set_forward_context(decode_data.attn_metadata,
self.vllm_config):
assert self.model is not None
selected_token_ids = self.model(decode_data.input_tokens,
decode_data.input_positions,
self.kv_caches)
# Transfer sampled tokens from TPU to CPU
decode_token_ids_cpu = selected_token_ids.cpu()
# Convert to list
decode_token_ids_list = decode_token_ids_cpu.tolist()
# Update cached state for each decode request
for i in range(num_decodes):
req_id = pd_info.decode_req_ids[i]
req_index = i
assert req_index == self.input_batch.req_id_to_index[
req_id] # TODO: Remove
req_state = self.requests[req_id]
seq_len = req_state.num_computed_tokens + 1
token_id = decode_token_ids_list[i]
sampled_token_ids[req_index] = token_id
self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
self.input_batch.num_tokens[req_index] += 1
req_state.output_token_ids.append(token_id)
# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
token_ids=self.input_ids,
position_ids=self.position_ids,
kv_caches=self.kv_caches,
)
hidden_states = hidden_states[:total_num_scheduled_tokens]
num_reqs = self.input_batch.num_reqs
logits_indices = logits_indices[:num_reqs]
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, None)
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
# Then, let's update the cache state.
request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len >= req_state.num_tokens:
request_seq_lens.append((i, req_state, seq_len))
else:
# Ignore the sampled token from the partial request.
# Rewind the generator state as if the token was not sampled.
generator = self.input_batch.generators.get(i)
if generator is not None:
# This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4)
# num_reqs entries should be non-None
assert all(
req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
# Create output.
all_req_ids = pd_info.decode_req_ids + pd_info.prompt_req_ids
prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {}
for req_id in all_req_ids:
for req_id in self.input_batch.req_ids[:num_reqs]:
prompt_logprobs_dict[req_id] = None
max_gen_len = selected_token_ids.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = selected_token_ids.tolist()
for i, req_state, seq_len in request_seq_lens:
token_id = valid_sampled_token_ids[i][0]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids.append(token_id)
self.input_batch.num_tokens[i] += 1
else:
valid_mask = selected_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()
valid_sampled_token_ids = [
seq.tolist()
for seq in selected_token_ids[valid_mask].split(gen_lens)
]
self.input_batch.num_tokens[:num_reqs] += gen_lens
for i, req_state, seq_len in request_seq_lens:
target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
self.input_batch.token_ids_cpu[
i, target_slice] = valid_sampled_token_ids[i]
req_state.output_token_ids.extend(valid_sampled_token_ids[i])
model_runner_output = ModelRunnerOutput(
req_ids=all_req_ids,
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=[[token_id] for token_id in sampled_token_ids],
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type]
prompt_logprobs_dict=prompt_logprobs_dict,
)
return model_runner_output
def load_model(self) -> None:
......@@ -731,185 +544,63 @@ class TPUModelRunner:
self,
kv_caches,
num_tokens: int,
seq_len: Optional[int] = None,
exec_mode: Optional[ExecutionMode] = None,
) -> None:
assert seq_len is not None
assert exec_mode is not None
exec_mode = ExecutionMode(exec_mode)
if exec_mode.is_prefill():
seq_len = (seq_len + 15) // 16 * 16
token_ids = torch.zeros((num_tokens, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((num_tokens, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((num_tokens, seq_len),
dtype=torch.int64,
device=self.device)
if exec_mode == ExecutionMode.PREFILL:
attn_metadata = PallasMetadata(
num_prefills=num_tokens,
num_prefill_tokens=num_tokens * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=None,
context_lens=None,
effective_query_lens=None,
)
else:
context_lens = torch.ones((num_tokens, ),
dtype=torch.int32,
device=self.device)
block_tables = torch.zeros(
(num_tokens, self.max_num_blocks_per_req),
dtype=torch.int32,
device=self.device)
effective_query_lens = torch.ones_like(context_lens)
attn_metadata = PallasMetadata(
num_prefills=num_tokens,
num_prefill_tokens=num_tokens * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_tables,
context_lens=context_lens,
effective_query_lens=effective_query_lens,
)
else:
assert seq_len == 1
token_ids = torch.zeros((num_tokens, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((num_tokens, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((num_tokens, seq_len),
dtype=torch.int64,
device=self.device)
block_tables = torch.zeros(
(num_tokens, self.max_num_blocks_per_req),
dtype=torch.int32,
device=self.device)
context_lens = torch.ones((num_tokens, ),
dtype=torch.int32,
device=self.device)
attn_metadata = PallasMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=num_tokens * seq_len,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
block_tables=block_tables,
context_lens=context_lens,
)
input_ids = torch.zeros(num_tokens,
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros(num_tokens,
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros(num_tokens,
dtype=torch.int64,
device=self.device)
block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]),
dtype=torch.int32,
device=self.device)
query_lens = [1] * num_tokens
query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
dtype=torch.int32),
dim=0,
dtype=torch.int32).to(self.device)
context_lens = torch.ones((num_tokens, ),
dtype=torch.int32,
device=self.device)
attn_metadata = PallasMetadata(
slot_mapping=slot_mapping,
block_tables=block_tables,
context_lens=context_lens,
query_start_loc=query_start_loc,
num_seqs=num_tokens,
)
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
# overhead by reusing the FX graph for different shapes.
# However, the XLA graph will still require static shapes and needs to
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
if exec_mode.is_prefill():
# Prefill
torch._dynamo.mark_dynamic(token_ids, 1)
torch._dynamo.mark_dynamic(position_ids, 1)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
else:
# Decode
torch._dynamo.mark_dynamic(token_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
torch._dynamo.mark_dynamic(input_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None
self.model(token_ids, position_ids, kv_caches)
self.model(input_ids, position_ids, kv_caches)
def capture_model(self) -> None:
"""Compile the model."""
# Prefill
logger.info(
"Compiling the model with different input shapes for prefill:")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self.dummy_run(self.kv_caches,
batch_size,
seq_len,
exec_mode=ExecutionMode.PREFILL)
xm.wait_device_ops()
logger.info(" batch_size: %d, seq_len: %d", batch_size,
seq_len)
num_tokens = batch_size * seq_len
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
break
seq_len = seq_len * 2
end = time.time()
logger.info(" -- Compilation for prefill done in %.2f [secs].",
end - start)
# Prefix prefill
if self.scheduler_config.enable_chunked_prefill:
logger.info("Compiling the model with different input shapes for "
"prefix prefill:")
start = time.time()
for batch_size in [1]:
seq_len = 16
while seq_len <= self.model_config.max_model_len:
self.dummy_run(self.kv_caches,
batch_size,
seq_len,
exec_mode=ExecutionMode.PREFIX_PREFILL)
xm.wait_device_ops()
logger.info(" batch_size: %d, seq_len: %d", batch_size,
seq_len)
num_tokens = batch_size * seq_len
if (num_tokens
>= self.scheduler_config.max_num_batched_tokens):
break
seq_len = seq_len * 2
end = time.time()
logger.info(
" -- Compilation for prefix prefill done in %.2f [secs].",
end - start)
# Decode
logger.info(
"Compiling the model with different input shapes for decode:")
start = time.time()
seq_len = 1
batch_size = 8 # Must be in sync with _get_padded_batch_size()
logger.info("Compiling the model with different input shapes.")
start = time.perf_counter()
num_tokens = 16
while True:
self.dummy_run(self.kv_caches,
batch_size,
seq_len,
exec_mode=ExecutionMode.DECODE)
self.dummy_run(self.kv_caches, num_tokens)
logger.info(" -- num_tokens: %d", num_tokens)
xm.mark_step()
xm.wait_device_ops()
logger.info(" batch_size: %d, seq_len: %d", batch_size, seq_len)
if batch_size >= self.scheduler_config.max_num_seqs:
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
break
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
end = time.time()
logger.info(" -- Compilation for decode done in %.2f [secs].",
end - start)
num_tokens *= 2
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
......@@ -965,12 +656,8 @@ class ModelWrapperV1(nn.Module):
"""Executes the forward pass of the model and samples the next token.
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
num_samples: Number of samples to draw from each logits vector.
token_ids: The input token IDs of shape [num_tokens].
position_ids: The input position IDs of shape [num_tokens].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
"""
......@@ -982,6 +669,7 @@ class ModelWrapperV1(nn.Module):
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
# kv_caches: List[Tuple[torch.Tensor, torch.Tensor]]
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
slot_mapping = attn_metadata.slot_mapping
slot_mapping = slot_mapping.flatten()
......@@ -997,103 +685,22 @@ class ModelWrapperV1(nn.Module):
attn_metadata.slot_mapping = slot_mapping
assert self.model is not None
hidden_states = self.model(token_ids, position_ids)
hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, None)
# Greedy sampling.
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
return argmax_token_ids
def swap_positions(b: InputBatch, id_1, id_2):
    """Exchange the per-request state stored at indices ``id_1`` and ``id_2``.

    The request ids, the id-to-index mapping, the token/sampling arrays,
    the block-table rows, ``min_tokens`` and any per-request generators are
    all swapped so that each request's state ends up at the other index.

    Args:
        b: The input batch whose state is modified in place.
        id_1: Index of the first request; must differ from ``id_2``.
        id_2: Index of the second request.
    """
    assert id_1 != id_2
    first_req = b.req_ids[id_1]
    second_req = b.req_ids[id_2]
    assert first_req is not None
    assert second_req is not None
    assert id_1 == b.req_id_to_index[first_req]
    assert id_2 == b.req_id_to_index[second_req]

    # Request bookkeeping: id list and reverse lookup table.
    b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1]
    (b.req_id_to_index[first_req],
     b.req_id_to_index[second_req]) = (b.req_id_to_index[second_req],
                                       b.req_id_to_index[first_req])

    forward = [id_1, id_2]
    backward = [id_2, id_1]

    def _exchange(rows):
        # Index-list assignment: the right-hand side is read before the
        # left-hand side is written (assumes array-like storage that
        # supports list indexing, e.g. NumPy - TODO confirm).
        rows[forward] = rows[backward]

    _exchange(b.num_tokens)
    _exchange(b.token_ids_cpu)
    _exchange(b.num_prompt_tokens)
    _exchange(b.num_computed_tokens_cpu)

    b.block_table.swap_row(id_1, id_2)

    _exchange(b.temperature_cpu)
    _exchange(b.top_p_cpu)
    _exchange(b.top_k_cpu)
    _exchange(b.frequency_penalties_cpu)
    _exchange(b.presence_penalties_cpu)
    _exchange(b.repetition_penalties_cpu)

    b.min_tokens[id_1], b.min_tokens[id_2] = (b.min_tokens[id_2],
                                              b.min_tokens[id_1])

    # Generators are optional per request: remove both first, then
    # re-insert each one (if present) under the other index.
    gen_from_first = b.generators.pop(id_1, None)
    gen_from_second = b.generators.pop(id_2, None)
    for slot, gen in ((id_2, gen_from_first), (id_1, gen_from_second)):
        if gen is not None:
            b.generators[slot] = gen
def ensure_decodes_first(b: InputBatch):
    """Reorder ``b`` in place so that all decode requests precede prompts.

    A request counts as a prompt while its computed-token count is below
    its prompt-token count, and as a decode otherwise. The earliest prompt
    is repeatedly swapped with the latest decode until no prompt sits
    before a decode.

    Args:
        b: The input batch to reorder; modified via ``swap_positions``.
    """
    total = b.num_reqs

    def _is_prompt(idx):
        return b.num_computed_tokens_cpu[idx] < b.num_prompt_tokens[idx]

    def _is_decode(idx):
        return b.num_computed_tokens_cpu[idx] >= b.num_prompt_tokens[idx]

    while True:
        # Earliest request that is still a prompt (None if there is none).
        earliest_prompt = next(
            (idx for idx in range(total) if _is_prompt(idx)), None)
        if earliest_prompt is None:
            return

        # Latest request that is a decode (None if there is none).
        latest_decode = next(
            (idx for idx in reversed(range(total)) if _is_decode(idx)), None)
        if latest_decode is None:
            return

        # Sanity
        assert earliest_prompt != latest_decode

        # Every decode already sits before every prompt: nothing left to do.
        if earliest_prompt > latest_decode:
            return

        swap_positions(b, earliest_prompt, latest_decode)
hidden_states = self.model(
token_ids,
position_ids,
kv_caches,
)
return hidden_states
def _get_padded_prompt_len(x: int) -> int:
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
# length to be a multiple of 16. We pad the prompt length to the nearest
# multiple of 16. This is also good for performance.
if x <= 16:
return 16
return 1 << (x - 1).bit_length()
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata,
) -> Optional[torch.Tensor]:
    """Compute logits by delegating to the wrapped model.

    Args:
        hidden_states: Hidden states forwarded unchanged to the wrapped
            model's ``compute_logits``.
        sampling_metadata: Forwarded unchanged as the second argument.

    Returns:
        Whatever ``self.model.compute_logits`` returns.
    """
    return self.model.compute_logits(hidden_states, sampling_metadata)
def _get_padded_batch_size(batch_size: int) -> int:
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
# To meet this requirement in the simplest way, we set the minimal batch
# size to 8.
if batch_size <= 8:
return 8
else:
return ((batch_size + 15) // 16) * 16
def _get_padded_number(n: int, multiple: int) -> int:
return ((n + multiple - 1) // multiple) * multiple
......@@ -21,7 +21,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
logger = init_logger(__name__)
......@@ -126,9 +126,7 @@ class TPUWorker:
self.model_runner.dummy_run(
runner_kv_caches,
num_tokens=1,
seq_len=self.scheduler_config.max_num_batched_tokens,
exec_mode=ExecutionMode.PREFILL,
num_tokens=self.scheduler_config.max_num_batched_tokens,
)
# Synchronize before measuring the memory usage.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment