Unverified Commit e7596371 authored by weiyu's avatar weiyu Committed by GitHub
Browse files

[Refactor][TPU] Remove torch_xla path and use tpu-inference (#30808)


Signed-off-by: default avatarWei-Yu Lin <weiyulin@google.com>
Signed-off-by: default avatarweiyu <62784299+weiyu0824@users.noreply.github.com>
parent 0dd5dee9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
from vllm.logger import init_logger
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
class TPUModelLoader(DefaultModelLoader):
"""
A TPU model loader for model loading under SPMD mode.
"""
def load_model(
self,
vllm_config: VllmConfig,
model_config: ModelConfig,
mesh: xs.Mesh | None = None,
) -> nn.Module:
# Initialize model and load weights on CPU. Then, during SPMD partition,
# weights are sharded and transferred to TPUs.
self.counter_before_loading_weights = time.perf_counter()
model_config = vllm_config.model_config
assert model_config.quantization is None, "Quantization not supported"
target_device = torch.device("cpu")
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config)
load_format = vllm_config.load_config.load_format
if load_format != "dummy":
weights_to_load = {name for name, _ in model.named_parameters()}
all_weights = self.get_all_weights(model_config, model)
loaded_weights = model.load_weights(all_weights)
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights
- self.counter_before_loading_weights,
)
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}"
)
else:
logger.info("Use dummy weight during weight loading.")
process_weights_after_loading(model, model_config, target_device)
counter_before_partition = time.perf_counter()
model = model.eval()
model = model.to("xla")
shard_model(model, mesh)
counter_after_partition = time.perf_counter()
logger.info(
"Partition model took %.2f seconds",
counter_after_partition - counter_before_partition,
)
# Ensure the model is properly loaded.
self._check_model_is_loaded(mesh, model)
# Need to torch compile after model sharding are done. Because the
# compiler hints ('xs.mark_sharding') are torch ops.
if not model_config.is_multimodal_model:
model.model = torch.compile(model.model, backend="openxla")
else:
model.language_model.model = torch.compile(
model.language_model.model, backend="openxla"
)
return model
def _check_model_is_loaded(self, mesh: xs.Mesh | None, model: nn.Module) -> None:
"""
Ensure the model is properly loaded.
1. All model parameters and buffers are on XLA device.
2. Non-SPMD friendly layers are replaced as expected.
"""
device = xm.xla_device()
device_type = str(device.type)
# Check parameters
for name, param in model.named_parameters():
assert param.device.type == device_type, (
f"Parameter {name} is on {param.device.type} instead of {device_type}"
)
# Check buffers
for name, buffer in model.named_buffers():
assert buffer.device.type == device_type, (
f"Buffer {name} is on {buffer.device.type} instead of {device_type}"
)
for module in model.modules():
if (mesh is not None) and (get_fqn(module) == "QKVParallelLinear"):
raise AssertionError(
"QKVParallelLinear should be replaced by \
XlaQKVParallelLinear under SPMD mode."
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from typing import TYPE_CHECKING, Optional, cast
import torch
from tpu_info import device
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from typing import TypeAlias
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
ParamsType: TypeAlias = SamplingParams | PoolingParams
else:
BlockSize = None
VllmConfig = None
PoolingParams = None
ParamsType = None
logger = init_logger(__name__)
USE_TPU_INFERENCE = False
class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
device_name: str = "tpu"
device_type: str = "tpu"
dispatch_key: str = "XLA"
ray_device_key: str = "TPU"
dist_backend: str = "gloo"
device_control_env_var: str = "TPU_VISIBLE_CHIPS"
simple_compile_backend: str = "openxla"
supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"]
additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"]
@classmethod
def import_kernels(cls) -> None:
# Do not import vllm._C
with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
logger.info("Using Pallas V1 backend.")
return AttentionBackendEnum.PALLAS.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [
AttentionBackendEnum.PALLAS,
]
@classmethod
def get_vit_attn_backend(
cls,
head_size: int,
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention"
f"Supported backends are: {cls.get_supported_vit_attn_backends()}."
)
logger.info_once(f"Using backend {backend} for vit attention.")
return backend
logger.info_once(
f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention."
)
return AttentionBackendEnum.PALLAS
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
Set the device for the current platform.
"""
torch.tpu.set_device(device)
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
chip_type, _ = device.get_local_chips()
return f"TPU {chip_type.name}"
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
raise NotImplementedError
@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
@classmethod
def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
return torch.finfo(dtype).min, torch.finfo(dtype).max
@classmethod
def can_update_inplace(cls):
return False
@classmethod
def get_lora_vocab_padding_size(cls) -> int:
return 1
@classmethod
def inference_mode(cls):
return torch.no_grad()
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.config import CompilationMode, CUDAGraphMode
cache_config = vllm_config.cache_config
# For v0, the default block size is 16.
if cache_config and cache_config.block_size is None:
cache_config.block_size = cast(BlockSize, 16)
compilation_config = vllm_config.compilation_config
# TPU only supports DYNAMO_TRACE_ONCE compilation mode
if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE:
logger.info(
"[TPU] Forcing DYNAMO_TRACE_ONCE compilation mode, and\
disabling cudagraph."
)
compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
if (
compilation_config.cudagraph_mode is None
or compilation_config.cudagraph_mode.max_cudagraph_mode()
!= CUDAGraphMode.NONE
):
logger.info(
"[TPU] CUDA graph is not supported on TPU, disabling cudagraphs."
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
if compilation_config.backend == "":
compilation_config.backend = "openxla"
assert vllm_config.speculative_config is None, (
"TPU does not support speculative decoding"
)
model_config = vllm_config.model_config
if model_config is not None and model_config.dtype in (
torch.float16,
torch.float32,
):
logger.warning(
"The TPU backend currently does not support %s. "
"Using bfloat16 instead.",
model_config.dtype,
)
model_config.dtype = torch.bfloat16
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config) # type: ignore[assignment]
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
assert not vllm_config.speculative_config, (
"Speculative decoding is not yet supported for TPU backend"
)
if (
scheduler_config.is_multimodal_model
and not scheduler_config.disable_chunked_mm_input
):
logger.warning(
"TPU does not support running Multimodal models"
" without setting `--disable_chunked_mm_input`. "
"Forcing --disable_chunked_mm_input."
)
scheduler_config.disable_chunked_mm_input = True
if model_config and model_config.use_mla:
logger.info(
"MLA is enabled on a non-GPU platform; forcing chunked "
"prefill and prefix caching to be disabled."
)
vllm_config.scheduler_config.enable_chunked_prefill = False
vllm_config.scheduler_config.max_num_batched_tokens = max(
vllm_config.model_config.max_model_len,
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on TPU.")
return False
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa
@classmethod
def validate_request(
cls,
prompt: PromptType,
params: ParamsType,
processed_inputs: ProcessorInputs,
) -> None:
"""Raises if this request is unsupported on this platform"""
from vllm.sampling_params import SamplingParams, SamplingType
if (
isinstance(params, SamplingParams)
and params.sampling_type == SamplingType.RANDOM_SEED
):
raise ValueError("Torch XLA does not support per-request seed.")
@classmethod
@torch.compile(backend="openxla")
def insert_blocks_to_device(
cls,
src_cache: torch.Tensor,
dst_cache: torch.Tensor,
src_block_indices: torch.Tensor,
dst_block_indices: torch.Tensor,
) -> None:
torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True)
dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device)
@classmethod
@torch.compile(backend="openxla")
def swap_out_blocks_to_host(
cls,
src_cache: torch.Tensor,
dst_cache: torch.Tensor,
src_block_indices: torch.Tensor,
dst_block_indices: torch.Tensor,
) -> None:
"""tpu blocks to cpu blocks"""
torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
@classmethod
def use_sync_weight_loader(cls) -> bool:
return True
@classmethod
def check_max_model_len(cls, max_model_len: int) -> int:
"""
Check max_model_len for the current platform.
"""
logger.warning(
"--max-model-len is not specified, "
"it's currently using model's default length %d, "
"which might be too large."
"Please input with --max-model-len based on your "
"request input length and output length, to avoid "
"unnecessary degradation.",
max_model_len,
)
return max_model_len
try:
from tpu_inference.platforms import (
......@@ -291,5 +14,7 @@ try:
TpuPlatform = TpuInferencePlatform # type: ignore
USE_TPU_INFERENCE = True
except ImportError:
logger.info("tpu_inference not found, using vLLM's TpuPlatform")
logger.error(
"tpu_inference not found, please install tpu_inference to run vllm on TPU"
)
pass
......@@ -186,20 +186,6 @@ class UsageMessage:
except Exception:
return False
def _report_torch_xla_usage(self) -> bool:
try:
import torch_xla
self.gpu_count = torch_xla.runtime.world_size()
self.gpu_type = torch_xla.tpu.get_tpu_type()
self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[
"bytes_limit"
]
self.cuda_runtime = "torch_xla"
return True
except Exception:
return False
def _report_usage_once(
self,
model_architecture: str,
......@@ -217,9 +203,7 @@ class UsageMessage:
if current_platform.is_cuda():
self.cuda_runtime = torch.version.cuda
if current_platform.is_tpu(): # noqa: SIM102
if (not self._report_tpu_inference_usage()) and (
not self._report_torch_xla_usage()
):
if not self._report_tpu_inference_usage():
logger.exception("Failed to collect TPU information")
self.provider = _detect_cloud_provider()
self.architecture = platform.machine()
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionType,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv, next_power_of_2
logger = init_logger(__name__)
# TPU requires the head size to be a multiple of 128.
TPU_HEAD_SIZE_ALIGNMENT = 128
# Note: TPU can fp8 as storage dtype but doesn't support converting from uint8
# from to fp32 directly. That's why it has a dtype mapping different from GPU
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.float8_e4m3fn,
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2,
"int8": torch.int8,
"uint8": torch.uint8,
}
try:
import tpu_inference # noqa: F401
except ImportError:
# Lazy import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.experimental.custom_kernel # noqa: F401
from torch.library import impl
from torch_xla._internal.jax_workarounds import requires_jax
from torch_xla.experimental.custom_kernel import XLA_LIB
@requires_jax
def kv_cache_update_op_impl(
kv: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int,
):
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
new_kv_cache = xb.call_jax(
kv_cache_update,
(kv, slot_mapping, kv_cache, num_kv_update_slices),
{"page_size": page_size, "num_slices_per_block": num_slices_per_block},
)
return new_kv_cache
XLA_LIB.define(
"kv_cache_update_op(Tensor kv, Tensor slot_mapping,"
"Tensor kv_cache, Tensor num_kv_update_slices, int page_size,"
"int num_slices_per_block)"
"-> Tensor",
)
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(
kv: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int,
) -> torch.Tensor:
new_kv_cache = kv_cache_update_op_impl(
kv,
slot_mapping,
kv_cache,
num_kv_update_slices,
page_size,
num_slices_per_block,
)
return new_kv_cache
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(
kv: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int,
) -> torch.Tensor:
return kv_cache
class PallasAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "PALLAS"
@staticmethod
def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
return PallasAttentionBackendImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
padded_head_size = (
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
)
return (num_blocks, block_size, num_kv_heads * 2, padded_head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
raise RuntimeError("swap_blocks is not used for the TPU backend.")
# In recent TPU generations, up to v6e, the SMEM size is 1MB. The
# block_tables within the PallasMetadata constitute almost the entire SMEM
# requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
# we simply make sure that the size is smaller than half of SMEM capacity.
@staticmethod
def get_min_page_size(vllm_config: VllmConfig) -> int:
max_num_page_per_req = (
1024 * 1024 // 2 // vllm_config.scheduler_config.max_num_seqs // 4
)
min_page_size = cdiv(
vllm_config.model_config.max_model_len, max_num_page_per_req
)
min_page_size = 1 << (min_page_size - 1).bit_length()
return min_page_size
@staticmethod
def get_max_num_seqs(model_len: int, page_size: int) -> int:
num_page_per_req = cdiv(model_len, page_size)
return 1024 * 1024 // 2 // num_page_per_req // 4
# TPU has limited SREGs (scalar registers), if page_size is too small, we
# can spill SREGs easily which leads to bad performance. The strategy we
# apply here is trying to split max-model-len to 16 pages which make the
# spill less likely. Meanwhile we make sure the page size is in [16, 256].
@staticmethod
def get_page_size(vllm_config: VllmConfig) -> int:
# TODO: This is a temporary fix for vmem OOM.
# For long model length, we use 16 page-size to avoid too much
# VMEM spill. A more robust solution should be implemented to
# handle VREG spills.
if vllm_config.model_config.max_model_len > 8192:
return 16
page_size = next_power_of_2(vllm_config.model_config.max_model_len) // 16
if page_size <= 16:
return 16
if page_size >= 256:
return 256
return page_size
@dataclass
class PallasMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Used in the PallasAttentionBackendImpl
slot_mapping: torch.Tensor
block_tables: torch.Tensor
context_lens: torch.Tensor
query_start_loc: torch.Tensor
num_seqs: torch.Tensor
num_kv_update_slices: torch.Tensor
num_slices_per_kv_cache_update_block: int
class PallasAttentionBackendImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: list[float] | None,
sliding_window: int | None,
kv_cache_dtype: str,
logits_soft_cap: float | None = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: int | None = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"PallasAttentionBackendImpl"
)
self.kv_cache_quantized_dtype = None
if kv_cache_dtype != "auto":
self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
kv_cache_dtype.lower().strip()
)
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: PallasMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache: shape =
[num_blocks, block_size, num_kv_heads * 2, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for PallasAttentionBackendImpl"
)
# For determine_available_memory case.
if kv_cache.numel() == 0:
if output is None:
output = torch.ones_like(query)
return output
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
padded_head_size = (
cdiv(self.head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
)
query = torch.nn.functional.pad(
query, (0, padded_head_size - self.head_size), value=0.0
)
key = torch.nn.functional.pad(
key, (0, padded_head_size - self.head_size), value=0.0
)
value = torch.nn.functional.pad(
value, (0, padded_head_size - self.head_size), value=0.0
)
if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
# Write input keys and values to the KV cache.
# Skip this if sharing KV cache with an earlier attention layer.
slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache(
key,
value,
kv_cache,
slot_mapping,
attn_metadata.num_slices_per_kv_cache_update_block,
attn_metadata.num_kv_update_slices,
self.kv_cache_quantized_dtype,
layer._k_scale_float,
layer._v_scale_float,
)
if self.kv_cache_quantized_dtype is not None and (
layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0
):
raise ValueError("k_scale_float and v_scale_float must be non-zero")
output = torch.ops.xla.ragged_paged_attention(
query,
kv_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.query_start_loc,
attn_metadata.num_seqs,
# By default, the system utilizes optimized block size and
# vmem_limit_bytes parameters from the kernel repository. However,
# these can be manually adjusted for debugging if necessary.
num_kv_pages_per_block=None,
num_queries_per_block=None,
vmem_limit_bytes=None,
use_kernel=True,
sm_scale=self.scale,
sliding_window=self.sliding_window,
soft_cap=self.logits_soft_cap,
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
)
if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0:
output = output[:, :, : self.head_size]
return output.reshape(num_tokens, hidden_size)
def write_to_kv_cache(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
num_slices_per_kv_cache_update_block: int,
num_kv_update_slices: torch.Tensor,
kv_cache_quantized_dtype: torch.dtype | None = None,
k_scale: float = 1.0,
v_scale: float = 1.0,
) -> None:
"""Write the key and values to the KV cache.
Args:
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size]
num_slices_per_kv_cache_update_block: int
"""
_, page_size, num_combined_kv_heads, head_size = kv_cache.shape
head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
if kv_cache_quantized_dtype is not None:
dtype_info = torch.finfo(kv_cache_quantized_dtype)
key = key.to(torch.float32) / k_scale
# NOTE: clamp is added here to avoid out of range of quantized dtype
key = torch.clamp(key, dtype_info.min, dtype_info.max)
key = key.to(kv_cache_quantized_dtype)
value = value.to(torch.float32) / v_scale
value = torch.clamp(value, dtype_info.min, dtype_info.max)
value = value.to(kv_cache_quantized_dtype)
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size)
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
kv_cache = kv_cache.flatten(0, 1)
new_kv_cache = torch.ops.xla.kv_cache_update_op(
kv,
slot_mapping,
kv_cache,
num_kv_update_slices,
page_size,
num_slices_per_kv_cache_update_block,
)
# NOTE: the in-place copy will be optimized away by XLA compiler.
kv_cache.copy_(new_kv_cache)
# We can move this function to a common utils file if it's also useful for other
# hardware.
def dtype_bits(dtype: torch.dtype):
if dtype.is_floating_point:
try:
return torch.finfo(dtype).bits
except TypeError:
pass
elif dtype.is_complex:
if dtype is torch.complex32:
return 32
elif dtype is torch.complex64:
return 64
elif dtype is torch.complex128:
return 128
else:
try:
return torch.iinfo(dtype).bits
# torch.iinfo cannot support int4, int2, bits8...
except TypeError:
pass
str_dtype = str(dtype)
# support torch.int4, torch.int5, torch.uint5...
if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"):
return int(str_dtype[-1])
raise TypeError(f"Getting the bit width of {dtype} is not supported")
def get_dtype_packing(dtype):
bits = dtype_bits(dtype)
if 32 % bits != 0:
raise ValueError(
f"The bit width must be divisible by 32, but got bits={bits}, "
"dtype={dtype}"
)
return 32 // bits
def get_page_size_bytes(
block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype
) -> int:
"""Returns the size in bytes of one page of the KV cache."""
padded_head_size = (
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
)
num_combined_kv_heads = num_kv_heads * 2
# NOTE: for the implicit padding in XLA
packing = get_dtype_packing(kv_cache_dtype)
num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing
kv_cache_dtype_bits = dtype_bits(kv_cache_dtype)
return (
block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bits // 8
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import bisect
import gc
import time
from typing import TYPE_CHECKING, Any, cast
from unittest.mock import patch
import numpy as np
import torch
import torch.nn as nn
# TPU XLA related
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention, MLAAttention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
from vllm.config import (
ParallelConfig,
VllmConfig,
get_layers_from_vllm_config,
update_config,
)
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMappingType
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.tpu import TPUModelLoader
from vllm.model_executor.models.interfaces import (
SupportsMultiModal,
supports_transcription,
)
from vllm.model_executor.models.interfaces_base import (
is_pooling_model,
is_text_generation_model,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
BatchedTensorInputs,
MultiModalKwargsItem,
PlaceholderRange,
)
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils.math_utils import cdiv, prev_power_of_2
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.pallas import (
TPU_STR_DTYPE_TO_TORCH_DTYPE,
PallasAttentionBackend,
PallasMetadata,
get_page_size_bytes,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
FullAttentionSpec,
KVCacheConfig,
KVCacheSpec,
MLAAttentionSpec,
SlidingWindowSpec,
)
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
LogprobsLists,
LogprobsTensors,
ModelRunnerOutput,
)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin,
KVConnectorOutput,
)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
from .utils import (
MultiModalBudget,
add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache,
sanity_check_mm_encoder_outputs,
)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
logger = init_logger(__name__)
INVALID_TOKEN_ID = -1
# Smallest output size
MIN_NUM_SEQS = 8
#########################################################
# Ways to avoid recompilation
#########################################################
#
# The model executor has two primary components:
# 1. preparing the model and sampler inputs
# 2. executing the model and sampler.
# The core idea is to avoid any TPU computation during input preparation. For
# better compilation tracking and increased flexibility, the model execution and
# sampler are divided into several distinct components.
#
# Below are the detailed steps:
#
# Step 1
# It is recommended to avoid TPU operations when preparing the model and sampler
# inputs. CPU tensors can be prepared and transferred to the XLA device using
# cpu_tensor.to(xla_device), which only triggers CPU to TPU transfers and avoids
# compilation.
#
# Step 2
# The TPU execution should be decomposed into subgraphs (4 at the moment):
# 1. the main model
# 2. selecting hidden states for each request
# 3. sampler
# 4. encoder.
# Each subgraph should be decorated in a torch.compile. This is used to make
# sure that we have the same subgraph topology in both dummy_run and
# xecute_model. The results from these subgraphs should either be passed to
# other subgraphs, or transferred from TPU to CPU using xla_tensor.cpu() for
# subsequent processing on the CPU.
#
# Step 3
# The dummy_run should be comprehensive, ensuring all potential input shapes and
# branch predictions are included as subgraph inputs to facilitate
# pre-compilation.
class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
original_parallel_config: ParallelConfig | None = None,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.original_parallel_config = original_parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self.device_config = vllm_config.device_config
model_config = self.model_config
cache_config = self.cache_config
scheduler_config = self.scheduler_config
parallel_config = self.parallel_config
self.device = device
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
# SPMD Related
self.use_spmd = envs.VLLM_XLA_USE_SPMD
if self.use_spmd:
num_devices = xr.global_runtime_device_count()
mesh_shape = (num_devices, 1)
device_ids = np.array(range(num_devices))
self.mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
self.enforce_eager = model_config.enforce_eager
self.num_xla_graphs = 0
self._update_num_xla_graphs("init")
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
if cache_config.cache_dtype == "auto":
model_dtype = self.dtype
if isinstance(model_dtype, str):
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
else:
self.kv_cache_dtype = model_dtype
else:
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
self._hidden_states_dtype = self.dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.num_blocks_per_most_len_req = (
cdiv(self.most_model_len, self.block_size)
if self.most_model_len is not None
else None
)
# InputBatch needs to work with sampling tensors greater than padding
# to avoid dynamic shapes. Also, avoid suboptimal alignment.
self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)
self.num_tokens_paddings = _get_token_paddings(
min_token_size=16,
max_token_size=scheduler_config.max_num_batched_tokens,
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP,
)
# In case `max_num_tokens < max(num_tokens_paddings)` use the actual
# padded max value to pre-allocate data structures and pre-compile.
self.max_num_tokens = self.num_tokens_paddings[-1]
# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
parallel_config, "attention"
)
self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.inputs_embeds_size = model_config.get_inputs_embeds_size()
self.vocab_size = model_config.get_vocab_size()
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
model_config
)
# TODO: Support M-RoPE (e.g, Qwen2-VL)
assert not self.uses_mrope, "TPU does not support M-RoPE yet."
self._num_slices_per_kv_cache_update_block = (
_get_num_slices_per_kv_cache_update_block(
get_page_size_bytes(
block_size=self.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
kv_cache_dtype=self.kv_cache_dtype,
)
)
)
# Lazy initialization
self.model: nn.Module # Set after load_model
self.kv_caches: list[torch.Tensor] = []
# mm_hash -> encoder_output
self.encoder_cache: dict[str, torch.Tensor] = {}
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
# Initialize input batch early to avoid AttributeError in _update_states
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.block_size],
kernel_block_sizes=[self.cache_config.block_size],
)
# Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer.
# Sometimes the numpy op is faster so we create both.
self.input_ids_cpu = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device="cpu"
)
self.positions_cpu = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device="cpu"
)
self.positions_np = self.positions_cpu.numpy()
self.block_table_cpu = torch.zeros(
(self.max_num_reqs, self.max_num_blocks_per_req),
dtype=torch.int32,
device="cpu",
)
# adjust num_reqs to avoid SMEM OOM.
self.num_reqs_most_model_len = (
min(
PallasAttentionBackend.get_max_num_seqs(
self.most_model_len, self.block_size
),
self.max_num_reqs,
)
if self.most_model_len is not None
else None
)
self.num_reqs_max_model_len = min(
PallasAttentionBackend.get_max_num_seqs(
self.max_model_len, self.block_size
),
self.max_num_reqs,
)
self.query_start_loc_cpu = torch.zeros(
self.max_num_tokens + 1,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
self.seq_lens_cpu = torch.zeros(
self.max_num_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
self.seq_lens_np = self.seq_lens_cpu.numpy()
# Only relevant for multimodal models
if self.supports_mm_inputs:
self.is_mm_embed_cpu = torch.zeros(
self.max_num_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory,
)
# Range tensor with values [0 .. self.max_num_tokens - 1].
# Used to initialize positions / context_lens / seq_lens
# Keep in int64 to avoid overflow with long context
self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64)
self.num_reqs_paddings = _get_req_paddings(
min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs
)
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
# means this layer will perform attention using the keys and values
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {}
# tensors for structured decoding
self.grammar_bitmask_cpu = torch.zeros(
(self.max_num_reqs, cdiv(self.vocab_size, 32)),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
self.require_structured_out_cpu = torch.zeros(
(self.max_num_reqs, 1),
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory,
)
self.structured_decode_arange = torch.arange(
0, 32, device="cpu", pin_memory=self.pin_memory
)
self.mm_budget = (
MultiModalBudget(
self.model_config,
self.scheduler_config,
self.mm_registry,
)
if self.supports_mm_inputs
else None
)
if not self.use_spmd:
self.sample_from_logits_func = torch.compile(
self.sample_from_logits,
backend="openxla",
fullgraph=True,
dynamic=False,
)
else:
self.sample_from_logits_func = self.sample_from_logits
# For passing scheduler_output between successive
# execute_model() and sample_tokens() calls.
self.scheduler_output: SchedulerOutput | None = None
self.mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp:
return
total_cached_graphs = xr.get_num_cached_compilation_graph()
new_compiled_graphs = total_cached_graphs - self.num_xla_graphs
if new_compiled_graphs == 0:
return
logger.info(
"Add new %d compiled XLA graphs due to %s", new_compiled_graphs, case_str
)
self.num_xla_graphs += new_compiled_graphs
def _verify_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp:
return
curr_cached_graph = xr.get_num_cached_compilation_graph()
assert self.num_xla_graphs == curr_cached_graph, (
"Recompilation after warm up is detected during {}."
" num_xla_graphs = {} curr_cached_graph = {}".format(
case_str, self.num_xla_graphs, curr_cached_graph
)
)
def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
"""Update the cached states and the persistent batch with the scheduler
output.
The updated states are used by the `_prepare_inputs` function to create
the input GPU tensors for the model.
Returns:
True if there is a new/resumed/paused/finished request.
If False, we can skip copying SamplingMetadata to the GPU.
"""
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
# then resubmitted with the same ID. In this case, we treat them as two
# distinct requests - clearing the cached states for the first request
# and handling the second as a new request.
removed_req_indices: list[int] = []
for req_id in scheduler_output.finished_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
# Free the cached encoder outputs.
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
# Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests
# or running requests that are not scheduled in this step. We remove
# them from the persistent batch but keep their cached states since
# they will be scheduled again sometime in the future.
scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
cached_req_ids = self.input_batch.req_id_to_index.keys()
unscheduled_req_ids = cached_req_ids - scheduled_req_ids
# NOTE(woosuk): The persistent batch optimization assumes that
# consecutive batches contain mostly the same requests. If batches
# have low request overlap (e.g., alternating between two distinct
# sets of requests), this optimization becomes very inefficient.
for req_id in unscheduled_req_ids:
req_index = self.input_batch.remove_request(req_id)
assert req_index is not None
removed_req_indices.append(req_index)
req_ids_to_add: list[str] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
assert new_req_data.sampling_params is not None, (
"Pooling is not supported in TPU yet"
)
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt_embeds=new_req_data.prompt_embeds,
mm_features=new_req_data.mm_features,
sampling_params=sampling_params,
pooling_params=None,
generator=None,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
lora_request=new_req_data.lora_request,
)
if sampling_params and sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = (
self.input_batch.vocab_size
if sampling_params.prompt_logprobs == -1
else sampling_params.prompt_logprobs
)
req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests.
req_data = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_id in req_data.resumed_req_ids
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
if not resumed_from_preemption:
if new_block_ids is not None:
# Append the new blocks to the existing block IDs.
for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
block_ids.extend(new_ids)
else:
assert new_block_ids is not None
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = new_block_ids
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
req_ids_to_add.append(req_id)
continue
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
if new_block_ids is not None:
self.input_batch.block_table.append_row(new_block_ids, req_index)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
# Fill the empty index or append to the end
req_index = removed_req_indices.pop() if removed_req_indices else None
self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
def get_model(self) -> nn.Module:
return self.model
def get_supported_generation_tasks(self) -> list[GenerationTask]:
model = self.get_model()
supported_tasks = list[GenerationTask]()
if is_text_generation_model(model):
supported_tasks.append("generate")
if supports_transcription(model):
if model.supports_transcription_only:
return ["transcription"]
supported_tasks.append("transcription")
return supported_tasks
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
model = self.get_model()
if not is_pooling_model(model):
return []
return list(model.pooler.get_supported_tasks())
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks = list[SupportedTask]()
if self.model_config.runner_type == "generate":
tasks.extend(self.get_supported_generation_tasks())
if self.model_config.runner_type == "pooling":
tasks.extend(self.get_supported_pooling_tasks())
return tuple(tasks)
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""
layers = get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
block_size = self.vllm_config.cache_config.block_size
cache_dtype_str = self.vllm_config.cache_config.cache_dtype
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in layers.items():
# Classic Attention path
if isinstance(attn_module, Attention):
if (
kv_tgt_layer := attn_module.kv_sharing_target_layer_name
) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
if attn_module.attn_type == AttentionType.DECODER:
if isinstance(attn_module, ChunkedLocalAttention):
logger.warning_once(
"Using irope in Pallas is not supported yet, it "
"will fall back to global attention for long context."
)
if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
)
elif attn_module.attn_type in (
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
# MLAAttention path
elif isinstance(attn_module, MLAAttention):
if layer_name in kv_cache_spec:
continue
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
cache_dtype_str=cache_dtype_str,
)
else:
continue
return kv_cache_spec
def _get_slot_mapping_metadata(
self, num_reqs, num_scheduled_tokens_per_req
) -> np.ndarray:
"""
Computes metadata for mapping slots to blocks in the key-value (KV)
cache for a batch of requests.
This function determines, for each request in the batch, how the
scheduled tokens are distributed across memory blocks, and generates
metadata needed to map slices of tokens to their corresponding positions
in the KV cache.
Args:
num_reqs (int): Number of requests in the current batch.
num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens
to be scheduled for each request.
Returns:
np.ndarray: A 2D array of shape (total_block_len, 3), where each row
contains:
- kv_cache_start_index (int): The starting index in the KV cache
for the corresponding slice.
- new_kv_start_index (int): The starting index in the new KV
cache for the corresponding slice.
- slice_len (int): The length of the slice.
"""
slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs]
slices_end = (
self.input_batch.num_computed_tokens_cpu[:num_reqs]
+ num_scheduled_tokens_per_req
)
local_block_start_idx = slices_start // self.block_size
local_block_end_idx = (slices_end - 1) // self.block_size
no_repeat_req_indices = self.arange_np[:num_reqs]
global_block_start_idx = (
no_repeat_req_indices * self.max_num_blocks_per_req + local_block_start_idx
)
block_lens = local_block_end_idx - local_block_start_idx + 1
global_block_start_idx = np.repeat(global_block_start_idx, block_lens)
slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens])
global_block_indices = global_block_start_idx + slice_arange
block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[global_block_indices].numpy()
total_block_len = np.sum(block_lens)
slot_mapping_slices = np.repeat(
np.array([[0, self.block_size]], dtype=np.int32), total_block_len, axis=0
)
cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32)
np.cumsum(block_lens, out=cu_block_lens[1:])
for req_idx in range(num_reqs):
slot_mapping_slices[cu_block_lens[req_idx]][0] = (
slices_start[req_idx] % self.block_size
)
slot_mapping_slices[cu_block_lens[req_idx + 1] - 1][1] = (
slices_end[req_idx] - 1
) % self.block_size + 1
slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0]
cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32)
np.cumsum(slice_lens, out=cu_slices_lens[1:])
kv_cache_start_indices = slot_mapping_slices[:, 0] + (
block_numbers * self.block_size
)
new_kv_start_indices = cu_slices_lens[:-1]
slot_mapping_metadata = np.stack(
[kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1
)
return slot_mapping_metadata
def _prepare_inputs(self, scheduler_output: "SchedulerOutput", start_index: int):
assert scheduler_output.total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
assert start_index < num_reqs
# Get the number of scheduled tokens for each request.
use_max_model_len = self.most_model_len is None
num_scheduled_tokens_per_req = []
max_num_scheduled_tokens_all_reqs = 0
end_index = start_index
# Use either most_model_len or max_model_len depending on request size.
for i in range(start_index, num_reqs):
req_id = self.input_batch.req_ids[i]
assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
if (
not use_max_model_len
and self.most_model_len is not None
and num_tokens > self.most_model_len
):
use_max_model_len = True
num_scheduled_tokens_per_req.append(num_tokens)
if use_max_model_len:
if len(num_scheduled_tokens_per_req) > self.num_reqs_max_model_len:
num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[
: self.num_reqs_max_model_len
]
end_index = start_index + self.num_reqs_max_model_len
else:
end_index = num_reqs
else:
assert self.num_reqs_most_model_len is not None
if len(num_scheduled_tokens_per_req) > self.num_reqs_most_model_len:
num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[
: self.num_reqs_most_model_len
]
end_index = start_index + self.num_reqs_most_model_len
else:
end_index = num_reqs
max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req)
num_scheduled_tokens_per_req = np.array(
num_scheduled_tokens_per_req, dtype=np.int32
)
total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req)
assert max_num_scheduled_tokens_all_reqs > 0
num_reqs = len(num_scheduled_tokens_per_req)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
# For each scheduled token, what are the corresponding req index.
req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens_per_req)
# Get batched arange.
# E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# For each scheduled token, what is its position in corresponding req.
arange = np.concatenate(
[self.arange_np[:n] for n in num_scheduled_tokens_per_req]
)
# Get positions.
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(
self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np,
)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
token_indices = (
positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]
)
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
torch.index_select(
self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens],
)
# Prepare the attention metadata.
self.query_start_loc_np[0] = 0
np.cumsum(
num_scheduled_tokens_per_req, out=self.query_start_loc_np[1 : num_reqs + 1]
)
self.query_start_loc_np[num_reqs + 1 :] = 1
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs]
+ num_scheduled_tokens_per_req
)
# Do the padding and copy the tensors to the TPU.
padded_total_num_scheduled_tokens = _get_padded_token_len(
self.num_tokens_paddings, total_num_scheduled_tokens
)
# Zero out to avoid spurious values from prev iteration (last cp chunk)
self.input_ids_cpu[
total_num_scheduled_tokens:padded_total_num_scheduled_tokens
] = 0
self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to(
self.device
)
self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to(
self.device
)
if use_max_model_len:
block_tables = self.block_table_cpu[
: self.num_reqs_max_model_len, : self.max_num_blocks_per_req
]
block_tables[:num_reqs, : self.max_num_blocks_per_req] = (
self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs]
)
query_start_loc = self.query_start_loc_cpu[
: self.num_reqs_max_model_len + 1
].to(self.device)
seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device)
else:
assert self.num_reqs_most_model_len is not None
block_tables = self.block_table_cpu[
: self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req
]
block_tables[:num_reqs, : self.num_blocks_per_most_len_req] = (
self.input_batch.block_table[0].get_cpu_tensor()[
:num_reqs, : self.num_blocks_per_most_len_req
]
)
query_start_loc = self.query_start_loc_cpu[
: self.num_reqs_most_model_len + 1
].to(self.device)
seq_lens = self.seq_lens_cpu[: self.num_reqs_most_model_len].to(self.device)
block_tables = block_tables.to(self.device)
# Calculate the slot mapping
slot_mapping_metadata = self._get_slot_mapping_metadata(
num_reqs, num_scheduled_tokens_per_req
)
num_kv_update_slices = slot_mapping_metadata.shape[0]
padded_num_slices = _get_padded_num_kv_cache_update_slices(
padded_total_num_scheduled_tokens, self.max_num_reqs, self.block_size
)
slot_mapping_metadata = np.pad(
slot_mapping_metadata,
[[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
constant_values=0,
)
slot_mapping_metadata = np.transpose(slot_mapping_metadata)
slot_mapping_metadata = torch.tensor(slot_mapping_metadata, device=self.device)
if self.lora_config is not None:
# We need to respect padding when activating LoRA adapters
padded_num_scheduled_tokens_per_req = np.copy(
num_scheduled_tokens_per_req
) # Copying to avoid accidental state corruption bugs
padded_num_scheduled_tokens_per_req[-1] += (
padded_total_num_scheduled_tokens - total_num_scheduled_tokens
)
self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req)
attn_metadata = PallasMetadata(
slot_mapping=slot_mapping_metadata,
block_tables=block_tables,
context_lens=seq_lens,
query_start_loc=query_start_loc,
num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device),
num_kv_update_slices=torch.tensor(
[num_kv_update_slices], dtype=torch.int32, device=self.device
),
num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block,
)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
# partial request, we do so for simplicity. We will ignore the sampled
# token from the partial request.
# TODO: Support prompt logprobs.
padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
num_reqs, self.max_num_reqs
)
# Indices at which we sample (positions of last token in the sequence).
# Padded to avoid recompiling when `num_reqs` varies.
logits_indices = self.query_start_loc_cpu[1 : padded_num_reqs + 1] - 1
logits_indices = logits_indices.to(self.device)
if self.lora_config is not None:
# We need to respect padding when activating LoRA adapters
padded_num_scheduled_tokens_per_req = np.copy(
num_scheduled_tokens_per_req
) # Copying to avoid accidental state corruption bugs
padded_num_scheduled_tokens_per_req[-1] += (
padded_total_num_scheduled_tokens - total_num_scheduled_tokens
)
self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req)
layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys()
per_layer_attn_metadata = {
layer_name: attn_metadata for layer_name in layer_names
}
return (
per_layer_attn_metadata,
logits_indices,
padded_num_reqs,
num_reqs,
end_index,
)
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
# Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]()
# List of tuple (mm_hash, pos_info)
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_feature = req_state.mm_features[mm_input_id]
if mm_feature.data is None:
continue
mm_hash = mm_feature.identifier
mm_kwargs.append(mm_feature.data)
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
# we process it separately to preserve item order.
# FIXME(ywang96): This is a hacky way to deal with multiple modalities
# in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
model = cast(SupportsMultiModal, self.model)
encoder_outputs = []
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=self.pin_memory,
):
# Run the encoder.
# `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
# in case feature_size is fixed across all multimodal items.
# 2. A list or tuple (length: num_items) of tensors, each of shape
# (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items.
torch_xla.sync(wait=False)
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
torch_xla.sync(wait=False)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=num_items,
)
if isinstance(curr_group_outputs, torch.Tensor):
encoder_outputs.append(curr_group_outputs)
else:
assert isinstance(curr_group_outputs, (list, tuple))
for output in curr_group_outputs:
encoder_outputs.append(output)
# Cache the encoder outputs.
# NOTE (NickLucche) here we diverge from logic in other runners, as we
# assume to only have whole mm items to process. Hence we avoid the
# intrinsic dynamism that `scatter_mm_placeholders` introduces.
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
assert pos_info.is_embed is None, (
"Expected all positions to be contiguous and embeddings."
)
self.encoder_cache[mm_hash] = output
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[list[torch.Tensor], torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
padded_total_num_scheduled_tokens = _get_padded_token_len(
self.num_tokens_paddings, total_num_scheduled_tokens
)
is_mm_embed = self.is_mm_embed_cpu
is_mm_embed[:padded_total_num_scheduled_tokens] = False
mm_embeds = list[torch.Tensor]()
req_start_idx = 0
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
# TODO unroll loop and assume/enforce --disable_chunked_mm_input
# NOTE (NickLucche) here we diverge from logic in other runners, as
# we assume to only have whole mm items to process. Hence we avoid
# the intrinsic dynamism that `gather_mm_placeholders` introduces.
for mm_feature in req_state.mm_features:
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
# num_computed_tokens + num_scheduled_tokens) and
# [start_pos, start_pos + num_encoder_tokens)
if start_pos >= num_computed_tokens + num_scheduled_tokens:
# The encoder output is not needed in this step.
break
if start_pos + num_encoder_tokens <= num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens,
)
assert start_idx < end_idx
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
assert pos_info.is_embed is None, (
"Expected all positions to be contiguous and embeddings."
)
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = True
# Only whole mm items are processed
mm_embeds.append(encoder_output)
req_start_idx += num_scheduled_tokens
is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens].to(self.device)
return mm_embeds, is_mm_embed
def _get_model_inputs(
self,
input_ids: torch.Tensor,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None,
):
if self.supports_mm_inputs:
mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
inputs_embeds = self.model.embed_input_ids(
input_ids,
multimodal_embeddings=mm_embeds,
is_multimodal=is_mm_embed,
)
return None, inputs_embeds
else:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
return input_ids, None
@torch.no_grad()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput | None:
if self.scheduler_output is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
# Update cached state
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
mm_embed_inputs = None
if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
torch_xla.sync(wait=False)
self.scheduler_output = scheduler_output
self.mm_embed_inputs = mm_embed_inputs
return None
@torch.no_grad()
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput:
if self.scheduler_output is None:
# Nothing to do (PP non-final rank case), output isn't used.
return None # type: ignore[return-value]
scheduler_output = self.scheduler_output
mm_embed_inputs = self.mm_embed_inputs
self.scheduler_output = None
self.mm_embed_inputs = None
# Prepare inputs, the requests might be split into multiple
# executions, combine the result of each execution.
start_index = 0
combined_selected_tokens: list[torch.Tensor] = []
combined_logprobs: list[LogprobsLists] = []
# NOTE: setup current batch's metadata for kv connector.
# Currently, only verified with NixlConnector
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
while start_index < self.input_batch.num_reqs:
attn_metadata, logits_indices, padded_num_reqs, num_reqs, end_index = (
self._prepare_inputs(scheduler_output, start_index)
)
input_ids, inputs_embeds = self._get_model_inputs(
self.input_ids, mm_embed_inputs
)
torch_xla.sync(wait=False)
# Run the decoder
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=scheduler_output.total_num_scheduled_tokens,
):
hidden_states = self.model(
input_ids=input_ids,
positions=self.position_ids,
inputs_embeds=inputs_embeds,
)
hidden_states = self.select_hidden_states(hidden_states, logits_indices)
logits = self.compute_logits(hidden_states)
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
self.input_batch, padded_num_reqs, self.device
)
if grammar_output is not None:
require_struct_decoding, grammar_bitmask_padded, arange = (
self.prepare_structured_decoding_input(logits, grammar_output)
)
logits = self.structured_decode(
require_struct_decoding, grammar_bitmask_padded, logits, arange
)
selected_token_ids = self.sample_from_logits_func(
logits, tpu_sampling_metadata
)
# NOTE (NickLucche) Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs. We can't enforce it
# due to recompilations outside torch.compiled code, so just make
# sure `sample_from_logits` does not modify the logits in-place.
logprobs = (
self.gather_logprobs(logits, selected_token_ids)
if tpu_sampling_metadata.logprobs
else None
)
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
combined_selected_tokens.append(selected_token_ids)
if tpu_sampling_metadata.logprobs:
combined_logprobs.append(logprobs.tolists())
start_index = end_index
# NOTE: current kv load and save get h2d/d2h copies involved.
# Those copies are blocking. Once they become async., kv_save
# should be called right after each single forward pass,
# instead of the forwards of the entire input batch.
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = self.get_finished_kv_transfers(
scheduler_output
)
selected_token_ids = torch.cat(combined_selected_tokens, dim=0)
if tpu_sampling_metadata.logprobs:
def concat_lists(input_lists):
result = []
for input_list in input_lists:
result.extend(input_list)
return result
logprobs_lists = LogprobsLists(
logprob_token_ids=concat_lists(
[lp.logprob_token_ids for lp in combined_logprobs]
),
logprobs=concat_lists([lp.logprobs for lp in combined_logprobs]),
sampled_token_ranks=concat_lists(
[lp.sampled_token_ranks for lp in combined_logprobs]
),
)
else:
logprobs_lists = None
# Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
discard_sampled_tokens_req_indices = []
num_reqs = self.input_batch.num_reqs
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
req_state = self.requests[req_id]
seq_len = (
req_state.num_computed_tokens
+ scheduler_output.num_scheduled_tokens[req_id]
)
if seq_len >= req_state.num_tokens:
request_seq_lens.append((i, req_state, seq_len))
else:
# Ignore the sampled token from the partial request.
# Rewind the generator state as if the token was not sampled.
generator = self.input_batch.generators.get(i)
if generator is not None:
# This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices.append(i)
assert all(
req_id is not None for req_id in self.input_batch.req_ids[:num_reqs]
), "req_ids contains None"
req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])
prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
for req_id in self.input_batch.req_ids[:num_reqs]:
prompt_logprobs_dict[req_id] = None
max_gen_len = selected_token_ids.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = selected_token_ids.tolist()
# Mask out the sampled tokens that should not be sampled.
# TODO: Keep in sync with gpu_model_runner.py, in particular
# the "else" case here
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
# Append sampled tokens
for i, req_state, seq_len in request_seq_lens:
token_id = valid_sampled_token_ids[i][0]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids.append(token_id)
self.input_batch.num_tokens_no_spec[i] += 1
else:
valid_mask = selected_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()
valid_sampled_token_ids = [
seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens)
]
self.input_batch.num_tokens_no_spec[:num_reqs] += gen_lens
for i, req_state, seq_len in request_seq_lens:
target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
self.input_batch.token_ids_cpu[i, target_slice] = (
valid_sampled_token_ids[i]
)
req_state.output_token_ids.extend(valid_sampled_token_ids[i])
kv_connector_output = (
None
if (finished_sending is None and finished_recving is None)
else KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
)
)
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
)
# Check there are no new graphs compiled - all the graphs should be
# captured and compiled during warm up.
self._verify_num_xla_graphs("execute_model")
return model_runner_output
def update_config(self, overrides: dict[str, Any]) -> None:
# TODO: TPU config may need extra validation
# https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754
allowed_config_names = {"load_config", "model_config"}
for config_name, config_overrides in overrides.items():
assert config_name in allowed_config_names, (
f"Config `{config_name}` not supported. "
f"Allowed configs: {allowed_config_names}"
)
config = getattr(self, config_name)
new_config = update_config(config, config_overrides)
setattr(self, config_name, new_config)
def load_model(self) -> None:
self.device = self.device_config.device
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
# process, the ranks can be different from the ranks internally assigned
# by the xm runtime. Therefore, there is a mismatch in the rank
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
# This is not a problem in linear layers because all-reduce is
# rank-agnostic. However, it matters for all-gather as the ranks
# determine the order of concatenating the output tensors.
# As a workaround, we use the xm's rank assignment only when loading
# the embedding weights.
xm_tp_rank = xr.global_ordinal()
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank,
):
try:
if self.use_spmd:
tpu_loader = TPUModelLoader(
load_config=self.vllm_config.load_config
)
model = tpu_loader.load_model(
vllm_config=self.vllm_config,
model_config=self.vllm_config.model_config,
mesh=self.mesh,
)
else:
model_loader = get_model_loader(self.load_config)
logger.info("Loading model from scratch...")
model = model_loader.load_model(
vllm_config=self.vllm_config, model_config=self.model_config
)
except RuntimeError as e:
raise RuntimeError(
f"Unable to load model, a likely reason is the model is "
"too large for the current device's HBM memory. "
"Consider switching to a smaller model "
"or sharding the weights on more chips. "
f"See the detailed error: {e}"
) from e
if self.lora_config is not None:
model = self.load_lora_model(model, self.vllm_config, self.device)
replace_set_lora(model)
# Sync all pending XLA execution during model initialization and weight
# loading.
torch_xla.sync(wait=False)
xm.wait_device_ops()
if not hasattr(self, "model"):
self.model = model
self.sampler = TPUSampler()
def reload_weights(self) -> None:
assert getattr(self, "model", None) is not None, (
"Cannot reload weights before model is loaded."
)
model_loader = get_model_loader(self.load_config)
logger.info("Reloading weights inplace...")
model_loader.load_weights(self.model, model_config=self.model_config)
@torch.no_grad()
def _dummy_run(self, num_tokens: int, num_reqs: int, num_blocks: int) -> None:
if self.supports_mm_inputs:
input_ids = None
inputs_embeds = torch.zeros(
(num_tokens, self.inputs_embeds_size),
dtype=self.dtype,
device=self.device,
)
else:
input_ids = torch.zeros((num_tokens), dtype=torch.int32).to(self.device)
inputs_embeds = None
actual_num_reqs = min(num_tokens, num_reqs)
position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device)
padded_num_slices = _get_padded_num_kv_cache_update_slices(
num_tokens, self.max_num_reqs, self.block_size
)
num_kv_update_slices = torch.tensor([padded_num_slices], dtype=torch.int32).to(
self.device
)
slot_mapping = torch.zeros((3, padded_num_slices), dtype=torch.int32).to(
self.device
)
block_tables = torch.zeros((num_reqs, num_blocks), dtype=torch.int32).to(
self.device
)
query_lens = [1] * num_reqs
query_start_loc = torch.cumsum(
torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32
).to(self.device)
context_lens = torch.ones((num_reqs,), dtype=torch.int32).to(self.device)
num_seqs = torch.tensor([actual_num_reqs], dtype=torch.int32).to(self.device)
attn_metadata = PallasMetadata(
slot_mapping=slot_mapping,
block_tables=block_tables,
context_lens=context_lens,
query_start_loc=query_start_loc,
num_seqs=num_seqs,
num_kv_update_slices=num_kv_update_slices,
num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block,
)
if self.supports_mm_inputs:
torch._dynamo.mark_dynamic(inputs_embeds, 0)
else:
torch._dynamo.mark_dynamic(input_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, (0, 1))
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys()
per_layer_attn_metadata = {
layer_name: attn_metadata for layer_name in layer_names
}
with (
self.maybe_select_dummy_loras(
self.lora_config, np.array([num_tokens], dtype=np.int32)
),
set_forward_context(per_layer_attn_metadata, self.vllm_config, 0),
):
out = self.model(
input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds
)
self._hidden_states_dtype = out.dtype
def _set_active_loras(
self,
prompt_lora_mapping: tuple[int, ...],
token_lora_mapping: tuple[int, ...],
lora_requests: set[LoRARequest],
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
) -> None:
torch_xla.sync(wait=False) # Captures input updates
super()._set_active_loras(
prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
)
torch_xla.sync(wait=False) # Captures metadata updates
def _precompile_mm_encoder(self) -> None:
if not self.supports_mm_inputs:
return
# Pre-compile MM encoder for all supported data modalities.
hf_config = self.vllm_config.model_config.hf_config
mm_budget = self.mm_budget
assert mm_budget is not None
max_items_per_seq_by_modality = mm_budget.max_items_per_batch_by_modality # noqa: E501
for mode, max_items_per_seq in max_items_per_seq_by_modality.items():
logger.info(
"Compiling Multimodal %s Encoder with different input shapes.", mode
)
start = time.perf_counter()
# No padding for MM encoder just yet.
for num_items in range(1, max_items_per_seq + 1):
logger.info(" -- mode: %s items: %d", mode, num_items)
batched_dummy_mm_inputs = self._get_mm_dummy_batch(
mode,
num_items,
)
# Run multimodal encoder.
torch_xla.sync(wait=False)
mm_embeds = self.model.embed_multimodal(**batched_dummy_mm_inputs)
torch_xla.sync(wait=False)
num_patches = mm_embeds[0].shape[0]
items_size = num_patches * num_items
# NOTE (NickLucche) pre-compile `embed_input_ids` when mm
# embeddings are present. We assume `--disable-mm-chunked`,
# hence only whole items can be scheduled. This implies we just
# need to compile when `num_items` fit the (padded) `input_ids`
for num_tokens in self.num_tokens_paddings:
if num_tokens >= items_size:
# XLA Workaround: if torch.zeros(..device) is used, XLA
# compiles a scalar+expansion op, which won't match
# the graph generated at runtime. CPU->TPU must be used
placeholders_ids = torch.zeros(
num_tokens, dtype=torch.int32, device="cpu"
)
# Align placeholders and actual num mm_embeddings.
placeholders_ids[:items_size] = hf_config.image_token_index
placeholders_ids = placeholders_ids.to(self.device)
mm_mask = torch.tensor([False] * num_tokens)
mm_mask[:items_size] = True
mm_mask = mm_mask.to(self.device)
# Assign outputs or the graph will be cut short.
a, b = self._get_model_inputs(
placeholders_ids,
mm_embed_inputs=([mm_embeds], mm_mask),
)
assert a is None
torch_xla.sync(wait=False)
# Pre-compile `embed_input_ids` when mm_embeddings are not
# present. Chunk is only made of text, no mm_placeholders.
for num_tokens in self.num_tokens_paddings:
placeholders_ids = torch.zeros(
num_tokens, dtype=torch.int32, device="cpu"
)
placeholders_ids = placeholders_ids.to(self.device)
a, b = self._get_model_inputs(
placeholders_ids,
mm_embed_inputs=None,
)
assert a is None
torch_xla.sync(wait=False)
xm.wait_device_ops()
end = time.perf_counter()
logger.info(
"Multimodal %s Encoder compilation finished in in %.2f [secs].",
mode,
end - start,
)
def _precompile_backbone(self) -> None:
logger.info("Compiling the model with different input shapes.")
start = time.perf_counter()
for num_tokens in self.num_tokens_paddings:
logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(
num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req
)
if self.most_model_len is not None:
self._dummy_run(
num_tokens,
self.num_reqs_most_model_len,
self.num_blocks_per_most_len_req,
)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("model backbone")
def _precompile_select_hidden_states(self) -> None:
# Compile hidden state selection function for bucketed
# n_tokens x max_num_reqs. Graph is really small so this is fine.
logger.info("Compiling select_hidden_states with different input shapes.")
start = time.perf_counter()
hsize = self.model_config.get_hidden_size()
for num_tokens in self.num_tokens_paddings:
dummy_hidden = torch.zeros(
(num_tokens, hsize), device=self.device, dtype=self._hidden_states_dtype
)
torch._dynamo.mark_dynamic(dummy_hidden, 0)
for num_reqs in self.num_reqs_paddings:
indices = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
torch._dynamo.mark_dynamic(indices, 0)
self.select_hidden_states(dummy_hidden, indices)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, num_reqs)
# Requests can't be more than tokens. But do compile for the
# next bigger value in case num_tokens uses bucketed padding.
if num_reqs >= min(num_tokens, self.max_num_reqs):
break
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("select_hidden_states")
def _precompile_compute_logits(self) -> None:
logger.info("Compiling compute_logits with different input shapes.")
start = time.perf_counter()
hsize = self.model_config.get_hidden_size()
for num_reqs in self.num_reqs_paddings:
dummy_hidden = torch.zeros(
(num_reqs, hsize), device=self.device, dtype=self._hidden_states_dtype
)
torch._dynamo.mark_dynamic(dummy_hidden, 0)
self.compute_logits(dummy_hidden)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("compute_logits")
def _precompile_structured_decoding(self) -> None:
logger.info("Compiling structured_decoding with different input shapes.")
start = time.perf_counter()
for num_reqs in self.num_reqs_paddings:
dummy_logits = torch.zeros(
(num_reqs, self.vocab_size),
device=self.device,
dtype=self._hidden_states_dtype,
)
dummy_require_struct_decoding = self.require_structured_out_cpu[
:num_reqs
].to(self.device)
dummy_grammar_bitmask = self.grammar_bitmask_cpu[:num_reqs].to(self.device)
# The first dimension of the above 3 dummy tensors cannot be
# mark_dynamic because some operations in structured_decode require
# them to be static.
arange = self.structured_decode_arange.to(self.device)
self.structured_decode(
dummy_require_struct_decoding,
dummy_grammar_bitmask,
dummy_logits,
arange,
)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("structured_decoding")
def _precompile_sample_from_logits(self) -> None:
logger.info("Compiling sample_from_logits with different input shapes.")
start = time.perf_counter()
for num_reqs in self.num_reqs_paddings:
dummy_logits = torch.zeros(
(num_reqs, self.vocab_size),
device=self.device,
dtype=self._hidden_states_dtype,
)
# The first dimension of dummy_logits cannot be mark_dynamic
# because some operations in the sampler require it to be static.
for all_greedy in [False, True]:
generate_params_if_all_greedy = not all_greedy
sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
self.input_batch,
num_reqs,
self.device,
generate_params_if_all_greedy,
)
sampling_metadata.all_greedy = all_greedy
with self.maybe_select_dummy_loras(
self.lora_config, np.array([num_reqs], dtype=np.int32)
):
self.sample_from_logits_func(dummy_logits, sampling_metadata)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("sample_from_logits")
def _precompile_gather_logprobs(self) -> None:
logger.info("Compiling gather_logprobs with different input shapes.")
start = time.perf_counter()
for num_reqs in self.num_reqs_paddings:
dummy_logits = torch.zeros(
(num_reqs, self.vocab_size),
device=self.device,
dtype=self._hidden_states_dtype,
)
dummy_tokens = torch.zeros((num_reqs, 1), dtype=torch.int64).to(self.device)
with self.maybe_select_dummy_loras(
self.lora_config, np.array([num_reqs], dtype=np.int32)
):
self.gather_logprobs(dummy_logits, dummy_tokens)
logger.info(" -- num_seqs: %d", num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
self._update_num_xla_graphs("gather_logprobs")
def capture_model(self) -> None:
"""
Precompile all the subgraphs with possible input shapes.
"""
with self.maybe_setup_dummy_loras(self.lora_config):
self._precompile_mm_encoder()
self._precompile_backbone()
self._precompile_select_hidden_states()
self._precompile_compute_logits()
self._precompile_structured_decoding()
self._precompile_sample_from_logits()
self._precompile_gather_logprobs()
def profile_run(
self,
num_tokens: int,
) -> None:
# Profile with multimodal encoder & encoder cache.
if self.supports_mm_inputs:
mm_config = self.model_config.multimodal_config
if mm_config is not None and mm_config.skip_mm_profiling:
logger.info(
"Skipping memory profiling for multimodal encoder and "
"encoder cache."
)
else:
mm_budget = self.mm_budget
assert mm_budget is not None
# TODO: handle encoder-decoder models once we support them.
if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
dummy_modality = mm_budget.get_modality_with_max_tokens()
max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[
dummy_modality
]
logger.info(
"Encoder cache will be initialized with a budget of "
"%s tokens, and profiled with %s %s items of the "
"maximum feature size.",
encoder_budget,
max_mm_items_per_batch,
dummy_modality,
)
# Create dummy batch of multimodal inputs.
batched_dummy_mm_inputs = self._get_mm_dummy_batch(
dummy_modality,
max_mm_items_per_batch,
)
# Run multimodal encoder.
# Isolate encoder graph from post-processing to minimize
# impact of recompilation until it's fixed.
start = time.perf_counter()
torch_xla.sync(wait=False)
dummy_encoder_outputs = self.model.embed_multimodal(
**batched_dummy_mm_inputs
)
torch_xla.sync(wait=False)
xm.wait_device_ops()
end = time.perf_counter()
logger.info(
"Multimodal Encoder profiling finished in %.2f [secs].",
end - start,
)
sanity_check_mm_encoder_outputs(
dummy_encoder_outputs,
expected_num_items=max_mm_items_per_batch,
)
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
# Trigger compilation for general shape.
self._dummy_run(
num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req
)
if self.most_model_len is not None:
self._dummy_run(
num_tokens,
self.num_reqs_most_model_len,
self.num_blocks_per_most_len_req,
)
torch_xla.sync(wait=False)
xm.wait_device_ops()
self.encoder_cache.clear()
gc.collect()
def maybe_setup_cross_layer_kv_sharing(
self,
kv_caches: dict[str, torch.Tensor],
kv_cache_config: KVCacheConfig,
) -> None:
"""
Add layers that re-use KV cache to KV cache group of its target layer.
Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
"""
if not self.shared_kv_cache_layers:
# No cross-layer KV sharing, return
return
add_kv_sharing_layers_to_kv_cache_groups(
self.shared_kv_cache_layers,
kv_cache_config.kv_cache_groups,
)
for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
kv_caches[layer_name] = kv_caches[target_layer_name]
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
if len(kv_cache_config.kv_cache_groups) > 1:
raise NotImplementedError(
"Hybrid models with more than one KV cache type are not supported yet."
)
if (
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
!= self.block_size
):
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
kernel_block_sizes=[
kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size
],
)
# Verify dtype compatibility between block_table_cpu and input_batch
assert (
self.block_table_cpu.dtype
== self.input_batch.block_table[0].get_cpu_tensor().dtype
)
kv_cache_sizes = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
assert len(kv_cache_tensor.shared_by) == 1, (
"KV cache tensor shared by multiple layers is not supported in TPU."
)
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group.kv_cache_spec
for layer_name in kv_cache_group.layer_names:
tensor_size = kv_cache_sizes[layer_name]
assert tensor_size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_size // kv_cache_spec.page_size_bytes # noqa
if isinstance(kv_cache_spec, AttentionSpec):
if self.use_spmd:
num_kv_heads = kv_cache_spec.num_kv_heads
assert self.original_parallel_config is not None
tp_size = self.original_parallel_config.tensor_parallel_size
# TODO: Handle kv cache duplication under SPMD mode.
assert num_kv_heads % tp_size == 0, (
f"num_kv_heads {num_kv_heads} must be divisible by "
f"tp_size {tp_size} under SPMD mode"
)
kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
)
dtype = kv_cache_spec.dtype
tpu_kv_cache = torch.zeros(kv_cache_shape, dtype=dtype).to(
self.device
)
kv_caches[layer_name] = tpu_kv_cache
else:
raise NotImplementedError
# Set up cross-layer KV cache sharing if needed
self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config)
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches,
)
if self.use_spmd:
# Shard KV Cache
for cache in self.kv_caches:
xs.mark_sharding(cache, self.mesh, (None, "x", None, None))
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks)
def reset_dynamo_cache(self):
# NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs`
# since the compiled model object of the language backbone of a
# multimodal model needs to be extracted via `get_language_model`.
if self.model_config.is_multimodal_model:
compiled_model = self.model.get_language_model().model
else:
compiled_model = self.model.model
if isinstance(compiled_model, TorchCompileWithNoGuardsWrapper):
logger.info("Clear dynamo cache and cached dynamo bytecode.")
torch._dynamo.eval_frame.remove_from_cache(
compiled_model.original_code_object()
)
# Reset the wrapper to re-initialize.
compiled_model.compiled = False
TorchCompileWithNoGuardsWrapper.__init__(compiled_model)
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def select_hidden_states(self, hidden_states, indices_do_sample):
return hidden_states[indices_do_sample]
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def compute_logits(self, sample_hidden_states: torch.Tensor) -> torch.Tensor:
return self.model.compute_logits(sample_hidden_states)
# TODO: Under SPMD mode, sample_from_logits has correctness issue.
# Re-enable the torch.compile once the issue is fixed in torchxla.
# @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def sample_from_logits(
self, logits: torch.Tensor, sampling_metadata: TPUSupportedSamplingMetadata
) -> torch.Tensor:
"""
Sample with xla-friendly function. This function is to be traced
separately from `forward` for lighter compilation overhead.
"""
if sampling_metadata.all_greedy:
out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
else:
out_tokens = self.sampler(logits, sampling_metadata).sampled_token_ids
return out_tokens
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def gather_logprobs(
self, logits: torch.Tensor, sampled_tokens: torch.Tensor
) -> LogprobsTensors:
"""
Gather the top_logprobs with corresponding tokens. Use a fixed number
of logprobs as an alternative to having multiple pre-compiled graphs.
Select the number of logprobs actually demanded by each request on CPU.
"""
logprobs = self.sampler.compute_logprobs(logits)
return self.sampler.gather_logprobs(
logprobs,
self.model_config.max_logprobs,
token_ids=sampled_tokens.squeeze(-1),
)
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def structured_decode(
self,
require_struct_decoding: torch.Tensor,
grammar_bitmask: torch.Tensor,
logits: torch.Tensor,
arange: torch.Tensor,
) -> torch.Tensor:
return torch.where(
require_struct_decoding,
self.apply_grammar_bitmask(logits, grammar_bitmask, arange),
logits,
)
def apply_grammar_bitmask(
self, logits: torch.Tensor, grammar_bitmask: torch.Tensor, arange: torch.Tensor
):
assert logits.shape[0] == grammar_bitmask.shape[0]
logits_cloned = logits.clone()
for i in range(logits.shape[0]):
unpacked_bitmask = (
torch.bitwise_right_shift(grammar_bitmask[i][:, None], arange[None, :])
& 1
) == 0
unpacked_bitmask = unpacked_bitmask.reshape(-1)[: self.vocab_size]
logits_cloned[i] = logits_cloned[i].masked_fill(
unpacked_bitmask, -float("inf")
)
return logits_cloned
def embed_multimodal(self, *args, **kwargs):
return self.model.embed_multimodal(*args, **kwargs)
def embed_input_ids(self, *args, **kwargs):
return self.model.embed_input_ids(*args, **kwargs)
def prepare_structured_decoding_input(
self, logits: torch.Tensor, grammar_output: "GrammarOutput"
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
grammar_bitmask = grammar_output.grammar_bitmask
num_reqs, _ = logits.shape
# Reset pre-allocated tensors
self.grammar_bitmask_cpu.zero_()
self.require_structured_out_cpu.zero_()
cumulative_mask_idx = 0
for req_id in grammar_output.structured_output_request_ids:
if req_id not in self.input_batch.req_id_to_index:
continue
batch_index = self.input_batch.req_id_to_index[req_id]
self.grammar_bitmask_cpu[batch_index] = torch.from_numpy(
grammar_bitmask[cumulative_mask_idx]
)
# It's not guaranteed that all requests in this batch require
# structured output, so create a bool tensor to represent
# the requests that need structured output.
self.require_structured_out_cpu[batch_index] = True
cumulative_mask_idx += 1
return (
self.require_structured_out_cpu[:num_reqs].to(logits.device),
self.grammar_bitmask_cpu[:num_reqs].to(logits.device),
self.structured_decode_arange.to(logits.device),
)
def _get_mm_dummy_batch(
self,
modality: str,
max_items_per_batch: int,
) -> BatchedTensorInputs:
"""Dummy data for profiling and precompiling multimodal models."""
assert self.mm_budget is not None
dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config,
seq_len=self.max_model_len,
mm_counts={modality: 1},
cache=self.mm_budget.cache,
)
dummy_mm_data = dummy_decoder_data.multi_modal_data
# Result in the maximum GPU consumption of the model
dummy_mm_item = dummy_mm_data[modality][0]
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
return next(
grouped_mm_kwargs
for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
dummy_mm_items,
device=self.device,
pin_memory=self.pin_memory,
)
)
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
logger.info("Preparing request paddings:")
# assert min_req_size is power of 2
assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0
paddings: list = []
num = max(MIN_NUM_SEQS, min_req_size)
while num <= max_req_size and (len(paddings) == 0 or paddings[-1] != num):
paddings.append(num)
logger.info(" %d", num)
num = _get_padded_num_reqs_with_upper_limit(num + 1, max_req_size)
return paddings
def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
return min(res, upper_limit)
def _get_token_paddings(
min_token_size: int, max_token_size: int, padding_gap: int
) -> list[int]:
"""Generate a list of padding size, starting from min_token_size,
ending with a number that can cover max_token_size
If padding_gap == 0 then:
increase 2X each time (exponential)
else:
first increase the size to twice,
then increase the padding size by padding_gap.
"""
# assert min_token_size is power of 2
assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0
paddings = []
num = min_token_size
if padding_gap == 0:
logger.info("Using exponential token paddings:")
while True:
logger.info(" %d", num)
paddings.append(num)
if num >= max_token_size:
break
num *= 2
else:
logger.info("Using incremental token paddings:")
while num <= padding_gap:
logger.info(" %d", num)
paddings.append(num)
num *= 2
num //= 2
while num < max_token_size:
num += padding_gap
logger.info(" %d", num)
paddings.append(num)
return paddings
def _get_padded_token_len(paddings: list[int], x: int) -> int:
"""Return the first element in paddings list greater or equal to x."""
index = bisect.bisect_left(paddings, x)
assert index < len(paddings)
return paddings[index]
def _get_padded_num_kv_cache_update_slices(
num_tokens: int, max_num_reqs: int, page_size: int
) -> int:
"""Calculates the padded number of KV cache update slices to avoid
recompilation."""
# NOTE(chengjiyao): let's say R_i is the token num for i-th request,
# so it occupies most 2 + R_i // page_size pages. The total maximum
# possible number of pages needed is sum(2 + R_i // page_size), which
# is <= 2 * max_num_reqs + sum(R_i) // page_size
# = 2 * max_num_reqs + num_tokens // page_size
padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
padded_num_slices = min(padded_num_slices, num_tokens)
return padded_num_slices
def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
"""Find the optimum number of slices to copy per Pallas program instance.
Increasing the number of slices copied in one instance of the kernel program
will increase HBM bandwidth utilization via more in-flight DMAs.
However, it will also use more VMEM, and experimentally, we observed
performance regression at 128 slices on v6e, likely due to running
out of scalar registers. Thus this function will limit the number of
slices to 64.
"""
# The default vmem_limit_bytes of a pallas kernel is 32MB. Here we
# calculate num_slices_per_block based on 16MB in case any register spills.
vmem_limit = 16 * 1024 * 1024
num_slices_per_block = vmem_limit // page_size_bytes
assert num_slices_per_block > 0, "Number of slices should be positive"
num_slices_per_block = prev_power_of_2(num_slices_per_block)
if num_slices_per_block > 64:
num_slices_per_block = 64
return num_slices_per_block
def replace_set_lora(model):
def _tpu_set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
):
# TODO: The integer index leads to a recompilation, but converting it
# to a tensor doesn't seem to work anymore. This might be fixed with a
# later release of torch_xla.
self._original_set_lora(index, lora_a, lora_b, embeddings_tensor)
torch_xla.sync(wait=False)
def _tpu_reset_lora(self, index: int):
self._original_reset_lora(index)
torch_xla.sync(wait=False)
for _, module in model.named_modules():
if isinstance(module, BaseLayerWithLoRA):
module._original_set_lora = module.set_lora
module._original_reset_lora = module.reset_lora
module.set_lora = _tpu_set_lora.__get__( # type: ignore[method-assign]
module, module.__class__
)
module.reset_lora = _tpu_reset_lora.__get__( # type: ignore[method-assign]
module, module.__class__
)
......@@ -2,348 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class."""
import os
from collections.abc import Callable
from typing import Any, TypeVar
from typing import TypeVar
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (
ensure_model_parallel_initialized,
init_distributed_environment,
)
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized,
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_INFERENCE
from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.utils import bind_kv_cache
logger = init_logger(__name__)
_R = TypeVar("_R")
if not USE_TPU_INFERENCE:
logger.info("tpu_inference not found, using vLLM's TPUWorker.")
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
class TPUWorker:
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
self.is_driver_worker = is_driver_worker
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.use_spmd = envs.VLLM_XLA_USE_SPMD
self.original_parallel_config = None
if self.use_spmd:
# Under SPMD mode, distributed env is initialized as if there is
# only one worker/device.
self.original_parallel_config = self.parallel_config
self.parallel_config.tensor_parallel_size = 1
self.parallel_config.pipeline_parallel_size = 1
self.parallel_config.world_size = 1
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
if self.cache_config.cache_dtype == "auto":
self.cache_dtype = self.model_config.dtype
else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]
# Delay profiler initialization to the start of the profiling.
# This is because in vLLM V1, MP runtime is initialized before the
# TPU Worker is initialized. The profiler server needs to start after
# MP runtime is initialized.
self.profiler = None
self.profile_dir = None
if vllm_config.profiler_config.profiler == "torch" and self.rank < 1:
# For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0.
self.profile_dir = vllm_config.profiler_config.torch_profiler_dir
logger.info(
"Profiling enabled. Traces will be saved to: %s", self.profile_dir
)
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
# ring, the xla tpu compiler flag
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
# fix this. It will be removed after the bug in XLA compiler is fixed.
os.environ["LIBTPU_INIT_ARGS"] = (
os.environ.get("LIBTPU_INIT_ARGS", "")
+ " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
" --xla_jf_conv_input_fusion=False"
)
# --xla_jf_conv_input_fusion=False is used to improve the perf of
# quantized matmul.
torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype)
# Initialize the distributed environment.
self._init_tpu_worker_distributed_environment(
self.vllm_config, self.rank, self.distributed_init_method, self.local_rank
)
# Device initialization should happen after initializing
# the distributed runtime.
self.device = xm.xla_device()
self.device_config.device = self.device
# Set random seed.
set_random_seed(self.model_config.seed)
xm.set_rng_state(self.model_config.seed, self.device)
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
# TODO (NickLucche) On gsm we compile 80+ graphs.
# Re-evaluate limit, with MM we may get close to this limit.
torch._dynamo.config.cache_size_limit = 128
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
# can have slightly different XLA graphs.
world_size = self.parallel_config.world_size
rank = xr.global_ordinal()
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
# Consequently, changes in optimization flags, which affect compilation
# results, don't change the cache key. This can result in the wrong
# compilation being used. To prevent this, disabling the XLA compilation
# cache during development is recommended.We can disable it by
# `export VLLM_XLA_CACHE_PATH=`
if envs.VLLM_XLA_CACHE_PATH:
per_rank_path = os.path.join(
envs.VLLM_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}"
)
xr.initialize_cache(per_rank_path, readonly=False)
# Init ModelRunner here, so that we have access to self.device.
self.model_runner = TPUModelRunner(
self.vllm_config, self.device, self.original_parallel_config
)
if rank == 0:
# If usage stat is enabled, collect relevant info.
report_usage_stats(self.vllm_config)
def determine_available_memory(self) -> int:
kv_caches: dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, AttentionSpec):
dtype = layer_spec.dtype
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather by specializing on the value `None`.
tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device)
kv_caches[layer_name] = tpu_kv_cache
else:
raise NotImplementedError(
f"Unsupported KV cache spec '{type(layer_spec)}'"
)
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches,
)
# `max_num_tokens >= max_num_batched_tokens` due to padding.
with self.model_runner.maybe_setup_dummy_loras(self.lora_config):
self.model_runner.profile_run(self.model_runner.max_num_tokens)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()
# During the profiling run, the model runs without KV cache. After
# the profiling run, the model always runs with KV cache. Here we clear
# the dynamo cache and cached bytecode to ensure the model always has
# one compiled bytecode. Having one FX graph/cached bytecode per
# compiled model is required for `support_torch_compile` decorator to
# skip dynamo guard.
with set_current_vllm_config(self.vllm_config):
self.model_runner.reset_dynamo_cache()
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
if self.use_spmd:
# This is a workaround for the TPU SPMD mode. The get_memory_info
# API doesn't work with SPMD mode in PyTorch/XLA.
# TODO: use xm.get_memory_info for SPMD once it's supported in
# PyTorch/XLA.
import tpu_info
chip_type, _ = tpu_info.device.get_local_chips()
device_usage = tpu_info.metrics.get_chip_usage(chip_type)
total_memory_size = device_usage[0].total_memory
current_mem = device_usage[0].memory_usage
else:
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
current_mem = m["bytes_used"]
# Ideally we would use profiled = m["peak_bytes_used"] to
# get weights + activations. But there is memory used during
# compilation / weight loading that impacts the peak and
# there is no way to reset peak memory in XLA, So we
# use the heuristic of 2% of weights.
profiled = current_mem * 1.02
# Calculate the TPU KV cache size based on profiling.
usable_memory_size = int(
total_memory_size * self.cache_config.gpu_memory_utilization
)
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
head_size = self.model_config.get_head_size()
if head_size > 0:
padded_head_size = (
cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
)
if padded_head_size != head_size:
logger.warning_once("head size is padded to %d", padded_head_size)
# We adjust the usable memory size for the KV cache to prevent OOM
# errors, even after padding the head_size.
tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
return int(tpu_kv_cache_bytes)
def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output)
def execute_model(
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | None:
return self.model_runner.execute_model(scheduler_output)
def profile(self, is_start: bool = True):
if self.rank < 1:
if self.profile_dir is None:
raise RuntimeError("Profiler is not enabled.")
if is_start:
if self.profiler is None:
self.profiler = xp.start_server(9012)
xp.start_trace(self.profile_dir)
else:
xp.stop_trace()
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)
def load_model(self) -> None:
self.model_runner.load_model()
def update_config(self, overrides: dict[str, Any]) -> None:
self.model_runner.update_config(overrides)
def reload_weights(self) -> None:
self.model_runner.reload_weights()
def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_runner.get_supported_tasks()
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
# Init kv cache connector here, because it requires
# `kv_cache_config`.
# NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
# because `initialize_kv_cache` will inject kv cache groups not
# related to kv cache connector (e.g. kv cache sharing layers).
ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
self.model_runner.initialize_kv_cache(kv_cache_config)
def check_health(self) -> None:
# worker will always be healthy as long as it's running.
return
def _init_tpu_worker_distributed_environment(
self,
vllm_config: VllmConfig,
rank: int,
distributed_init_method: str | None = None,
local_rank: int = -1,
) -> None:
"""Initialize the distributed environment."""
if self.use_spmd:
xr.use_spmd()
# NOTE(woosuk): This is just to initialize the TP group and broadcast
# the input objects on CPU. The all-reduce and all-gather ops on TPU
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
# own context.
parallel_config = vllm_config.parallel_config
init_distributed_environment(
world_size=parallel_config.world_size,
rank=rank,
local_rank=local_rank,
distributed_init_method=distributed_init_method or "env://",
backend=current_platform.dist_backend,
)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
)
def shutdown(self) -> None:
self.model_runner.ensure_kv_transfer_shutdown()
def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
"""Apply a function on the model inside this worker."""
return fn(self.get_model())
# TODO(weiyulin) Remove this file after adding an official way to use hardware plugin
if USE_TPU_INFERENCE:
from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment