Commit 99324e25 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.2' into v0.9.2-ori

parents cc7f22a8 a5dd03c1
......@@ -654,7 +654,6 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
......@@ -673,6 +672,7 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
......@@ -692,6 +692,11 @@ class FlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
assert (
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import os
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
......@@ -50,8 +49,7 @@ if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
"NHD").upper()
FLASHINFER_KV_CACHE_LAYOUT: str = envs.VLLM_KV_CACHE_LAYOUT or "NHD"
class FlashInferBackend(AttentionBackend):
......@@ -957,7 +955,6 @@ class FlashInferImpl(AttentionImpl):
self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if attn_type != AttentionType.DECODER:
......@@ -975,8 +972,14 @@ class FlashInferImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashInferImpl")
# TODO: directly write to output tensor
num_heads: int = self.num_heads
head_size: int = self.head_size
......
......@@ -148,7 +148,6 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=torch.bfloat16)
self.alibi_slopes = alibi_slopes_tensor
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if self.prefill_impl == 'fsdpa':
......@@ -181,6 +180,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
......@@ -193,6 +193,11 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for HPUAttentionImpl")
batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape
......
......@@ -145,7 +145,6 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.sliding_window is not None)
if logits_soft_cap is None:
......@@ -192,6 +191,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
......@@ -206,6 +206,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for IpexAttentionImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
......
......@@ -1334,11 +1334,17 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for MLAImplBase")
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLAImplBase")
if attn_metadata.is_profile_run and \
attn_metadata.context_chunk_workspace is not None:
# During the profile run try to simulate to worse case output size
......
......@@ -121,9 +121,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.num_kv_heads = num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.logits_soft_cap = logits_soft_cap
if head_size % 128 != 0:
......@@ -172,6 +171,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
......@@ -187,6 +187,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for PallasAttentionImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
......
......@@ -17,6 +17,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention
......@@ -37,11 +38,11 @@ def is_rocm_aiter_paged_attn_enabled() -> bool:
@cache
def _get_paged_attn_module() -> PagedAttention:
"""
Initializes the appropriate PagedAttention module from `attention/ops`,
Initializes the appropriate PagedAttention module from `attention/ops`,
which is used as helper function
by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
The choice of attention module depends on whether
The choice of attention module depends on whether
AITER paged attention is enabled:
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
- Otherwise, it defaults to using the original `PagedAttention`.
......@@ -527,7 +528,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.paged_attn_module = _get_paged_attn_module()
......@@ -584,6 +584,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
logger.debug("Using naive (SDPA) attention in ROCmBackend")
self.aiter_kv_scales_initialized = False
self.force_fp8_attention = (
get_current_vllm_config() is not None
and get_current_vllm_config().model_config.override_attention_dtype
== "fp8")
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
......@@ -593,6 +597,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim))
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: tuple[int, int]):
if self.use_triton_flash_attn:
return dtype == current_platform.fp8_dtype(
) and static and group_shape == (-1, -1) # per-tensor
# Only supported in the Triton backend
return False
def forward(
self,
layer: AttentionLayer,
......@@ -602,6 +615,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
......@@ -655,6 +669,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None and not self.use_triton_flash_attn:
raise NotImplementedError(
"fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
......@@ -770,9 +789,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query.dtype,
seq_lens,
make_attn_mask=causal_mask) # type: ignore
use_fp8_scales = (layer._q_scale and layer._k_scale
and layer._v_scale and layer._prob_scale
and self.kv_cache_dtype == "fp8")
and (self.kv_cache_dtype == "fp8"
or self.force_fp8_attention))
full_scales = (
layer._q_scale.item(), layer._k_scale.item(),
layer._v_scale.item(),
......@@ -791,6 +813,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks[0][None]
if attn_masks is not None else None,
full_scales,
output_scale,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
......@@ -880,7 +903,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert _PARTITION_SIZE_ROCM % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
dtype=query.dtype,
device=output.device,
)
exp_sums = torch.empty(
......@@ -914,9 +937,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
output_scale,
)
else:
output[num_prefill_tokens:] = paged_attn.forward_decode(
# PagedAttention does not support fused quant, manually quantize
if output_scale is None:
out_pa = output[num_prefill_tokens:]
else:
out_pa = torch.empty_like(output[num_prefill_tokens:],
dtype=query.dtype)
out_pa[:] = paged_attn.forward_decode(
decode_query,
key_cache,
value_cache,
......@@ -937,6 +968,14 @@ class ROCmFlashAttentionImpl(AttentionImpl):
layer._v_scale,
)
# Manually perform quantization
if output_scale is not None:
out_uq = out_pa.view(-1, self.num_heads * self.head_size)
out_q = output.view(-1, self.num_heads * self.head_size)
ops.scaled_fp8_quant(out_uq,
output_scale,
output=out_q[num_prefill_tokens:])
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
......
......@@ -65,7 +65,7 @@ class TorchSDPABackend(AttentionBackend):
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
raise NotImplementedError("Swap is not supported in TorchSDPABackend.")
@staticmethod
def copy_blocks(
......@@ -433,7 +433,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
......@@ -459,6 +458,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
kv_cache: torch.Tensor,
attn_metadata: TorchSDPAMetadata, # type: ignore
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
......@@ -473,6 +473,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for TorchSDPABackendImpl")
# For warming-up
if attn_metadata is None:
......
......@@ -373,7 +373,7 @@ class CommonAttentionState(AttentionState):
f"Expected attn_backend name to be either 'XFORMERS'," \
f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
f"got '{self.runner.attn_backend.get_name()}'"
self._add_additonal_input_buffers_for_enc_dec_model(
self._add_additional_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers)
return input_buffers
......@@ -427,7 +427,7 @@ class CommonAttentionState(AttentionState):
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
attn_metadata.num_encoder_tokens = 0
def _add_additonal_input_buffers_for_enc_dec_model(
def _add_additional_input_buffers_for_enc_dec_model(
self, attn_metadata, input_buffers: Dict[str, Any]):
"""
Saves additional input buffers specific to the encoder-decoder model
......
......@@ -415,7 +415,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
supported_head_sizes = PagedAttention.get_supported_head_sizes()
......@@ -435,6 +434,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
kv_cache: torch.Tensor,
attn_metadata: "XFormersMetadata",
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
......@@ -487,6 +487,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersImpl")
attn_type = self.attn_type
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
......
......@@ -80,6 +80,9 @@ class Attention(nn.Module):
calculate_kv_scales = False
if num_kv_heads is None:
num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, \
f"num_heads ({num_heads}) is not " \
f"divisible by num_kv_heads ({num_kv_heads})"
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
......@@ -206,7 +209,7 @@ class Attention(nn.Module):
if self.use_output:
output_shape = (output_shape
if output_shape is not None else query.shape)
output = torch.empty(output_shape,
output = torch.zeros(output_shape,
dtype=query.dtype,
device=query.device)
hidden_size = output_shape[-1]
......@@ -291,7 +294,9 @@ class MultiHeadAttention(nn.Module):
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
assert self.num_heads % self.num_kv_heads == 0, \
f"num_heads ({self.num_heads}) is not " \
f"divisible by num_kv_heads ({self.num_kv_heads})"
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
dtype = torch.get_default_dtype()
......@@ -301,12 +306,17 @@ class MultiHeadAttention(nn.Module):
block_size=16,
is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name())
if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
backend = _Backend.XFORMERS
if current_platform.is_rocm():
# currently, only torch_sdpa is supported on rocm
self.attn_backend = _Backend.TORCH_SDPA
else:
if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
_Backend.FLEX_ATTENTION):
backend = _Backend.XFORMERS
self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
} else _Backend.TORCH_SDPA
self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
} else _Backend.TORCH_SDPA
def forward(
self,
......@@ -430,6 +440,7 @@ def unified_attention_with_output(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
......@@ -444,7 +455,8 @@ def unified_attention_with_output(
value,
kv_cache,
attn_metadata,
output=output)
output=output,
output_scale=output_scale)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
......@@ -455,6 +467,7 @@ def unified_attention_with_output_fake(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
) -> None:
return
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
try:
import intel_extension_for_pytorch.llm.modules as ipex_modules
......@@ -29,7 +29,7 @@ class _PagedAttention:
head_size: int,
*args,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
return 2, num_blocks, block_size * num_kv_heads * head_size
@staticmethod
def split_kv_cache(
......@@ -120,7 +120,7 @@ class _PagedAttention:
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
......
......@@ -8,9 +8,7 @@ import torch
from neuronxcc import nki
from neuronxcc.nki.language import par_dim
def ceil_div(a, b):
return (a + b - 1) // b
from vllm.utils import cdiv
def is_power_of_2(x):
......@@ -35,11 +33,10 @@ def load_block_tables(block_tables_hbm, num_tiles, num_blocks_per_tile):
(num_tiles, num_blocks_per_tile))
block_tables_sbuf = nl.zeros(
(ceil_div(num_tiles,
B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
(cdiv(num_tiles, B_P_SIZE), par_dim(B_P_SIZE), num_blocks_per_tile),
dtype=nl.int32,
)
for i in nl.affine_range(ceil_div(num_tiles, B_P_SIZE)):
for i in nl.affine_range(cdiv(num_tiles, B_P_SIZE)):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(num_blocks_per_tile)[None, :]
block_tables_sbuf[i, i_p, i_f] = nl.load(
......@@ -83,7 +80,7 @@ def transform_block_tables_for_indirect_load(
assert is_power_of_2(
num_blocks_per_tile), f"{num_blocks_per_tile=} is not power of 2"
num_loads = ceil_div(num_blocks_per_tile, B_P_SIZE)
num_loads = cdiv(num_blocks_per_tile, B_P_SIZE)
block_tables_transposed = nl.ndarray(
(
num_loads,
......@@ -165,7 +162,7 @@ def load_kv_tile_from_cache(
equivalent to (par_dim(B_P_SIZE), seqlen_kv // B_P_SIZE * B_D_SIZE)
"""
# load key cache
num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
for load_idx in nl.affine_range(num_loads):
i_p = nl.arange(B_P_SIZE)[:, None]
i_f = nl.arange(tiled_block_size * B_D_SIZE)[None, :]
......@@ -605,7 +602,7 @@ def flash_paged_attention(
)
for large_k_tile_idx in nl.sequential_range(0, num_large_k_tile):
num_loads = ceil_div(num_blocks_per_large_tile, B_P_SIZE)
num_loads = cdiv(num_blocks_per_large_tile, B_P_SIZE)
cur_k_tile = nl.ndarray(
(par_dim(B_D_SIZE), LARGE_TILE_SZ),
dtype=kernel_dtype,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from vllm.utils import cdiv
def _kv_cache_update_kernel(
# Prefetch
slices_ref, # [3, padded_num_slices], list of (kv_cache_start,
# new_kv_start, slice_len)
# Input
new_kv_hbm_ref, # [num_tokens, num_combined_kv_heads, head_dim]
kv_cache_hbm_ref, # [total_num_pages * page_size, num_combined_kv_heads,
# head_dim]
# Output
_, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
# Scratch
scratch, # [num_slices_per_block, page_size, num_combined_kv_heads,
# head_dim]
sem,
):
async_copies = []
block_idx = pl.program_id(0)
num_slices_per_block = scratch.shape[0]
# Copy from new_kv_hbm_ref to scratch
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
new_kv_start = slices_ref[1, offset_i]
length = slices_ref[2, offset_i]
async_copy = pltpu.make_async_copy(
new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
scratch.at[i, pl.ds(0, length), ...],
sem,
)
async_copy.start()
async_copies.append(async_copy)
for async_copy in async_copies:
async_copy.wait()
# Copy from scratch to kv_cache_hbm_ref
async_copies.clear()
for i in range(num_slices_per_block):
offset_i = i + block_idx * num_slices_per_block
kv_cache_start = slices_ref[0, offset_i]
length = slices_ref[2, offset_i]
async_copy = pltpu.make_async_copy(
scratch.at[i, pl.ds(0, length), ...],
kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
sem,
)
async_copy.start()
async_copies.append(async_copy)
for async_copy in async_copies:
async_copy.wait()
@functools.partial(
jax.jit,
static_argnames=["page_size", "num_slices_per_block"],
)
def kv_cache_update(
new_kv: jax.Array, # [total_num_token, num_combined_kv_heads, head_dim]
slices: jax.
Array, # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
kv_cache: jax.
Array, # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
num_kv_update_slices: jax.Array, # [1]
*,
page_size: int = 32,
num_slices_per_block: int = 8,
):
assert slices.shape[1] % num_slices_per_block == 0
_, num_combined_kv_heads, head_dim = new_kv.shape
assert kv_cache.shape[1] == num_combined_kv_heads
assert kv_cache.shape[2] == head_dim
assert head_dim % 128 == 0
# TODO: Add dynamic check to make sure that the all the slice lengths are
# smaller or equal to page_size
in_specs = [
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
]
out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]
scalar_prefetches = [slices]
scratch = pltpu.VMEM(
(num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
new_kv.dtype,
)
scratch_shapes = [
scratch,
pltpu.SemaphoreType.DMA,
]
kernel = pl.pallas_call(
_kv_cache_update_kernel,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=len(scalar_prefetches),
in_specs=in_specs,
out_specs=out_specs,
grid=(cdiv(num_kv_update_slices[0], num_slices_per_block), ),
scratch_shapes=scratch_shapes,
),
out_shape=out_shape,
input_output_aliases={len(scalar_prefetches) + 1: 0},
)
return kernel(*scalar_prefetches, new_kv, kv_cache)[0]
......@@ -25,9 +25,14 @@ Not currently supported:
import torch
from vllm.platforms import current_platform
from vllm.platforms.rocm import on_gfx1x
from vllm.triton_utils import tl, triton
# Avoid misleading ROCm warning.
if current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx1x
else:
on_gfx1x = lambda *args, **kwargs: False
torch_dtype: tl.constexpr = torch.float16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment