Unverified Commit 288cc6c2 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] MLA with chunked prefill (#12639)


Signed-off-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: default avatarPatrick Horn <patrick.horn@gmail.com>
Co-authored-by: default avatarsimon-mo <xmo@berkeley.edu>
Co-authored-by: default avatarTyler Michael Smith <tyler@neuralmagic.com>
parent 900edbfa
...@@ -39,3 +39,10 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, ...@@ -39,3 +39,10 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
// Just for unittest // Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype); const double scale, const std::string& kv_cache_dtype);
void gather_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
\ No newline at end of file
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include "cuda_utils.h"
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
...@@ -570,3 +571,161 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, ...@@ -570,3 +571,161 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
} }
} }
namespace vllm {
// grid is launched with dimensions (batch, num_splits)
template <typename scalar_t>
__global__ void gather_cache(
const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
const int32_t block_size, const int32_t entry_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
// batch
const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = cu_seq_lens[bid];
const int32_t seq_end = cu_seq_lens[bid + 1];
const int32_t seq_len = seq_end - seq_start;
const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
const int32_t split_start = split * split_blocks;
const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);
const bool is_active_split = (split_start < tot_blocks);
const bool is_last_split = (split_end == tot_blocks);
if (!is_active_split) return;
int32_t full_blocks_end = split_end;
int32_t partial_block_size = 0;
// Adjust the pointer for the block_table for this batch.
// If seq_starts is provided, compute an offset based on (seq_starts[bid] /
// page_size)
const int32_t batch_offset = bid * block_table_stride;
int32_t offset = 0;
if (seq_starts != nullptr) {
offset = seq_starts[bid] / block_size;
}
const int32_t* batch_block_table = block_table + batch_offset + offset;
// Adjust dst pointer based on the cumulative sequence lengths.
dst += seq_start * dst_entry_stride;
if (is_last_split) {
partial_block_size = seq_len % block_size;
if (partial_block_size) full_blocks_end -= 1;
}
auto copy_entry = [&](const scalar_t* __restrict__ _src,
scalar_t* __restrict__ _dst) {
for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
_dst[i] = _src[i];
};
for (int pid = split_start; pid < full_blocks_end; ++pid) {
auto block_id = batch_block_table[pid];
auto block_start_ptr = src_cache + block_id * cache_block_stride;
auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
for (int eid = 0; eid < block_size; ++eid) {
copy_entry(block_start_ptr + eid * cache_entry_stride,
block_dst_ptr + eid * dst_entry_stride);
}
}
if (partial_block_size) {
auto block_id = batch_block_table[full_blocks_end];
auto block_start_ptr = src_cache + block_id * cache_block_stride;
auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
for (int eid = 0; eid < partial_block_size; ++eid) {
copy_entry(block_start_ptr + eid * cache_entry_stride,
block_dst_ptr + eid * dst_entry_stride);
}
}
}
} // namespace vllm
// Macro to dispatch the kernel based on the data type.
#define CALL_GATHER_CACHE(CPY_DTYPE) \
vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
void gather_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size,
std::optional<torch::Tensor> seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t block_size = src_cache.size(1);
int32_t entry_size = src_cache.flatten(2, -1).size(2);
TORCH_CHECK(block_table.dtype() == torch::kInt32,
"block_table must be int32");
TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
"cu_seq_lens must be int32");
if (seq_starts.has_value()) {
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
"seq_starts must be int32");
}
TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
"src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
"src_cache and cu_seq_lens must be on the same device");
if (seq_starts.has_value()) {
TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
"src_cache and seq_starts must be on the same device");
}
int64_t block_table_stride = block_table.stride(0);
int64_t cache_block_stride = src_cache.stride(0);
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);
// Decide on the number of splits based on the batch size.
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(1024);
TORCH_CHECK(src_cache.dtype() == dst.dtype(),
"src_cache and dst must have the same dtype");
const int dtype_bits = src_cache.element_size() * 8;
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
if (dtype_bits == 32) {
CALL_GATHER_CACHE(uint32_t);
} else if (dtype_bits == 16) {
CALL_GATHER_CACHE(uint16_t);
} else if (dtype_bits == 8) {
CALL_GATHER_CACHE(uint8_t);
} else {
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
}
}
...@@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) { ...@@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num; if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
} }
template <typename T>
inline constexpr std::enable_if_t<std::is_integral_v<T>, T> ceil_div(T a, T b) {
return (a + b - 1) / b;
}
\ No newline at end of file
...@@ -2,10 +2,14 @@ ...@@ -2,10 +2,14 @@
#include <stdio.h> #include <stdio.h>
#if defined(__CUDACC__) || defined(_NVHPC_CUDA) #if defined(__HIPCC__)
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__ #define HOST_DEVICE_INLINE __host__ __device__
#define DEVICE_INLINE __forceinline__ __device__ #define DEVICE_INLINE __device__
#define HOST_INLINE __forceinline__ __host__ #define HOST_INLINE __host__
#elif defined(__CUDACC__) || defined(_NVHPC_CUDA)
#define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
#define DEVICE_INLINE __device__ __forceinline__
#define HOST_INLINE __host__ __forceinline__
#else #else
#define HOST_DEVICE_INLINE inline #define HOST_DEVICE_INLINE inline
#define DEVICE_INLINE inline #define DEVICE_INLINE inline
...@@ -25,3 +29,13 @@ ...@@ -25,3 +29,13 @@
int64_t get_device_attribute(int64_t attribute, int64_t device_id); int64_t get_device_attribute(int64_t attribute, int64_t device_id);
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
namespace cuda_utils {
template <typename T>
HOST_DEVICE_INLINE constexpr std::enable_if_t<std::is_integral_v<T>, T>
ceil_div(T a, T b) {
return (a + b - 1) / b;
}
}; // namespace cuda_utils
\ No newline at end of file
#include <cudaTypedefs.h> #include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp" #include "c3x/scaled_mm_kernels.hpp"
#include "core/math.hpp" #include "cuda_utils.h"
/* /*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for This file defines quantized GEMM operations using the CUTLASS 3.x API, for
...@@ -33,7 +33,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, ...@@ -33,7 +33,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
auto make_group_shape = [](torch::Tensor const& x, auto make_group_shape = [](torch::Tensor const& x,
torch::Tensor const& s) -> GroupShape { torch::Tensor const& s) -> GroupShape {
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
return {ceil_div(x.size(0), s.size(0)), ceil_div(x.size(1), s.size(1))}; return {cuda_utils::ceil_div(x.size(0), s.size(0)),
cuda_utils::ceil_div(x.size(1), s.size(1))};
}; };
GroupShape a_scale_group_shape = make_group_shape(a, a_scales); GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
......
...@@ -493,6 +493,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -493,6 +493,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
"str kv_cache_dtype) -> ()"); "str kv_cache_dtype) -> ()");
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
// Gather cache blocks from src_cache to dst.
cache_ops.def(
"gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
} }
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
......
...@@ -682,8 +682,6 @@ def test_swap_blocks_mla( ...@@ -682,8 +682,6 @@ def test_swap_blocks_mla(
torch.ops._C_cache_ops.swap_blocks, torch.ops._C_cache_ops.swap_blocks,
(src_cache, dst_cache, block_mapping_tensor), (src_cache, dst_cache, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS, test_utils=DEFAULT_OPCHECK_TEST_UTILS,
cond=(kv_lora_rank == KV_LORA_RANKS[0]
and qk_rope_head_dim == QK_ROPE_HEAD_DIMS[0]),
) )
ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor) ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
...@@ -694,3 +692,76 @@ def test_swap_blocks_mla( ...@@ -694,3 +692,76 @@ def test_swap_blocks_mla(
dst_cache[dst].cpu(), dst_cache[dst].cpu(),
msg=f"Block {src} from src should have been swapped to block " msg=f"Block {src} from src should have been swapped to block "
f"{dst} in dst_cache.") f"{dst} in dst_cache.")
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype",
["auto"]) # You can also test "fp8" if needed.
@pytest.mark.parametrize("align_cache", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
num_blocks, max_seq_len, batch_size, dtype,
kv_cache_dtype, align_cache, device):
entry_size = kv_lora_rank + qk_rope_head_dim
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device, align_cache)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
seq_len_tensor = torch.randint(0,
max_seq_len + 1, (batch_size, ),
device=device)
total_tokens = seq_len_tensor.sum()
cu_seq_lens = torch.empty((batch_size + 1),
dtype=torch.int32,
device=device)
cu_seq_lens[0] = 0
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
print("seq_len_tensor", seq_len_tensor)
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
block_table = torch.empty((batch_size, num_blocks),
dtype=torch.int32,
device=device)
for b in range(batch_size):
perm = torch.randperm(num_blocks, device=device)
block_table[b, :] = perm
dst = torch.zeros((total_tokens, entry_size),
dtype=src_cache.dtype,
device=device)
expected_batches = []
for b in range(batch_size):
s = seq_len_tensor[b]
if s == 0:
continue
tot = tot_blocks_tensor[b]
blocks = block_table[b, :tot].tolist()
gathered_rows = []
for i in range(tot - 1):
gathered_rows.append(src_cache[blocks[i]])
remaining = s - (tot - 1) * block_size
gathered_rows.append(src_cache[blocks[-1], :remaining, :])
batch_expected = torch.cat(gathered_rows, dim=0)
expected_batches.append(batch_expected)
expected = torch.cat(expected_batches, dim=0)
opcheck(
torch.ops._C_cache_ops.gather_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, None),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
torch.testing.assert_close(dst, expected)
...@@ -1099,6 +1099,16 @@ def convert_fp8(output: torch.Tensor, ...@@ -1099,6 +1099,16 @@ def convert_fp8(output: torch.Tensor,
torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)
def gather_cache(src_cache: torch.Tensor,
dst: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
batch_size: int,
seq_starts: Optional[torch.Tensor] = None) -> None:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
def get_device_attribute(attribute: int, device: int) -> int: def get_device_attribute(attribute: int, device: int) -> int:
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
......
...@@ -4,16 +4,12 @@ from vllm.attention.backends.abstract import (AttentionBackend, ...@@ -4,16 +4,12 @@ from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionState, AttentionType) AttentionState, AttentionType)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
__all__ = [ __all__ = [
"Attention", "Attention", "AttentionBackend", "AttentionMetadata", "AttentionType",
"AttentionBackend", "AttentionMetadataBuilder", "Attention", "AttentionState",
"AttentionMetadata", "get_attn_backend", "get_flash_attn_version"
"AttentionType",
"AttentionMetadataBuilder",
"Attention",
"AttentionState",
"get_attn_backend",
] ]
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import triton
import triton.language as tl
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# can be used to combine partial attention results (in the split-KV case)
def merge_attn_states(
output: torch.Tensor,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: Optional[torch.Tensor] = None,
) -> None:
num_tokens = output.shape[0]
num_query_heads = output.shape[1]
head_size = output.shape[2]
padded_head_size = triton.next_power_of_2(head_size)
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
merge_attn_states_kernel[(num_tokens, num_query_heads)](
output,
output_lse,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
head_size,
padded_head_size,
output_lse is not None,
)
@triton.jit
def merge_attn_states_kernel(
output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
output_lse, # [NUM_HEADS, NUM_TOKENS]
prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
HEAD_SIZE: tl.constexpr,
PADDED_HEAD_SIZE: tl.constexpr,
OUTPUT_LSE: tl.constexpr,
):
token_idx = tl.program_id(0)
num_tokens = tl.num_programs(0)
head_idx = tl.program_id(1)
num_heads = tl.num_programs(1)
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
max_lse = tl.maximum(p_lse, s_lse)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse
out_se = (tl.exp(p_lse) + tl.exp(s_lse))
if OUTPUT_LSE:
out_lse = tl.log(out_se) + max_lse
tl.store(output_lse + head_idx * num_tokens + token_idx, out_lse)
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
head_mask = head_arange < HEAD_SIZE
p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
head_idx * HEAD_SIZE + head_arange,
mask=head_mask)
s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
head_idx * HEAD_SIZE + head_arange,
mask=head_mask)
# NOTE(woosuk): Be careful with the numerical stability.
# We should compute the scale first, and then multiply it with the output.
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
p_scale = tl.exp(p_lse) / out_se
s_scale = tl.exp(s_lse) / out_se
out = p_out * p_scale + s_out * s_scale
tl.store(output + token_idx * num_heads * HEAD_SIZE +
head_idx * HEAD_SIZE + head_arange,
out,
mask=head_mask)
...@@ -3332,19 +3332,6 @@ class VllmConfig: ...@@ -3332,19 +3332,6 @@ class VllmConfig:
current_platform.check_and_update_config(self) current_platform.check_and_update_config(self)
# If MLA is enabled, force disable chunked prefill and prefix caching
if self.model_config and self.model_config.use_mla:
logger.info("MLA is enabled; forcing chunked prefill and prefix "
"caching to be disabled.")
self.scheduler_config.enable_chunked_prefill = False
self.scheduler_config.chunked_prefill_enabled = False
self.scheduler_config.max_num_batched_tokens = max(
self.scheduler_config.max_model_len,
_DEFAULT_MAX_NUM_BATCHED_TOKENS)
if self.cache_config is not None:
self.cache_config.enable_prefix_caching = False
if not self.instance_id: if not self.instance_id:
self.instance_id = random_uuid()[:5] self.instance_id = random_uuid()[:5]
......
...@@ -1170,9 +1170,9 @@ class EngineArgs: ...@@ -1170,9 +1170,9 @@ class EngineArgs:
# long context (> 32K) models. This is to avoid OOM errors in the # long context (> 32K) models. This is to avoid OOM errors in the
# initial memory profiling phase. # initial memory profiling phase.
# For multimodal models, chunked prefill is disabled by default in # For multimodal models and models with MLA, chunked prefill is
# V0, but enabled by design in V1 # disabled by default in V0, but enabled by design in V1
if model_config.is_multimodal_model: if model_config.is_multimodal_model or model_config.use_mla:
self.enable_chunked_prefill = bool(envs.VLLM_USE_V1) self.enable_chunked_prefill = bool(envs.VLLM_USE_V1)
elif use_long_context: elif use_long_context:
...@@ -1207,7 +1207,6 @@ class EngineArgs: ...@@ -1207,7 +1207,6 @@ class EngineArgs:
msg = "Chunked prefill is not supported for pooling models" msg = "Chunked prefill is not supported for pooling models"
raise ValueError(msg) raise ValueError(msg)
speculative_config = SpeculativeConfig.maybe_create_spec_config( speculative_config = SpeculativeConfig.maybe_create_spec_config(
target_model_config=model_config, target_model_config=model_config,
target_parallel_config=parallel_config, target_parallel_config=parallel_config,
......
...@@ -162,6 +162,9 @@ def _per_token_group_quant_fp8( ...@@ -162,6 +162,9 @@ def _per_token_group_quant_fp8(
y_q_ptr, y_q_ptr,
y_s_ptr, y_s_ptr,
group_size, group_size,
# Num columns of y
y_num_columns,
y_row_stride,
# Avoid to divide zero # Avoid to divide zero
eps, eps,
# Information for float8 # Information for float8
...@@ -174,9 +177,14 @@ def _per_token_group_quant_fp8( ...@@ -174,9 +177,14 @@ def _per_token_group_quant_fp8(
quantization on a tensor. quantization on a tensor.
This function converts the tensor values into float8 values. This function converts the tensor values into float8 values.
""" """
groups_per_row = y_num_columns // group_size
# Map the program id to the row of X and Y it should compute. # Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0) g_id = tl.program_id(0)
y_ptr += g_id * group_size row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
y_q_ptr += g_id * group_size y_q_ptr += g_id * group_size
y_s_ptr += g_id y_s_ptr += g_id
...@@ -202,6 +210,7 @@ def _per_token_group_quant_fp8_colmajor( ...@@ -202,6 +210,7 @@ def _per_token_group_quant_fp8_colmajor(
group_size, group_size,
# Num columns of y # Num columns of y
y_num_columns, y_num_columns,
y_row_stride,
# Stride from one column to the next of y_s # Stride from one column to the next of y_s
y_s_col_stride, y_s_col_stride,
# Avoid to divide zero # Avoid to divide zero
...@@ -216,9 +225,14 @@ def _per_token_group_quant_fp8_colmajor( ...@@ -216,9 +225,14 @@ def _per_token_group_quant_fp8_colmajor(
quantization on a tensor. quantization on a tensor.
This function converts the tensor values into float8 values. This function converts the tensor values into float8 values.
""" """
groups_per_row = y_num_columns // group_size
# Map the program id to the row of X and Y it should compute. # Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0) g_id = tl.program_id(0)
y_ptr += g_id * group_size row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
y_q_ptr += g_id * group_size y_q_ptr += g_id * group_size
# Convert g_id the flattened block coordinate to 2D so we can index # Convert g_id the flattened block coordinate to 2D so we can index
...@@ -267,7 +281,7 @@ def per_token_group_quant_fp8( ...@@ -267,7 +281,7 @@ def per_token_group_quant_fp8(
assert (x.shape[-1] % group_size == 0), ( assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible " f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}") f"by `group_size` {group_size}")
assert x.is_contiguous(), "`x` must be contiguous" assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype) finfo = torch.finfo(dtype)
fp8_min = finfo.min fp8_min = finfo.min
...@@ -295,6 +309,7 @@ def per_token_group_quant_fp8( ...@@ -295,6 +309,7 @@ def per_token_group_quant_fp8(
x_s, x_s,
group_size, group_size,
x.shape[1], x.shape[1],
x.stride(0),
x_s.stride(1), x_s.stride(1),
eps, eps,
fp8_min=fp8_min, fp8_min=fp8_min,
...@@ -309,6 +324,8 @@ def per_token_group_quant_fp8( ...@@ -309,6 +324,8 @@ def per_token_group_quant_fp8(
x_q, x_q,
x_s, x_s,
group_size, group_size,
x.shape[1],
x.stride(0),
eps, eps,
fp8_min=fp8_min, fp8_min=fp8_min,
fp8_max=fp8_max, fp8_max=fp8_max,
......
...@@ -565,6 +565,10 @@ def round_up(x: int, y: int) -> int: ...@@ -565,6 +565,10 @@ def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y return ((x + y - 1) // y) * y
def round_down(x: int, y: int) -> int:
return (x // y) * y
def _generate_random_fp8( def _generate_random_fp8(
tensor: torch.Tensor, tensor: torch.Tensor,
low: float, low: float,
......
...@@ -5,12 +5,11 @@ from typing import Any, Dict, List, Optional, Tuple, Type ...@@ -5,12 +5,11 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import numpy as np import numpy as np
import torch import torch
import triton
import triton.language as tl
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType) AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import get_flash_attn_version from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv from vllm.utils import cdiv
...@@ -372,70 +371,4 @@ def cascade_attention( ...@@ -372,70 +371,4 @@ def cascade_attention(
# Merge prefix and suffix outputs, and store the result in output. # Merge prefix and suffix outputs, and store the result in output.
merge_attn_states(output, prefix_output, prefix_lse, suffix_output, merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
suffix_lse) suffix_lse)
\ No newline at end of file
def merge_attn_states(
output: torch.Tensor,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
) -> None:
num_tokens = output.shape[0]
num_query_heads = output.shape[1]
head_size = output.shape[2]
padded_head_size = triton.next_power_of_2(head_size)
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
merge_attn_states_kernel[(num_tokens, num_query_heads)](
output,
prefix_output,
prefix_lse,
suffix_output,
suffix_lse,
head_size,
padded_head_size,
)
@triton.jit
def merge_attn_states_kernel(
output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
HEAD_SIZE: tl.constexpr,
PADDED_HEAD_SIZE: tl.constexpr,
):
token_idx = tl.program_id(0)
num_tokens = tl.num_programs(0)
head_idx = tl.program_id(1)
num_heads = tl.num_programs(1)
p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx)
s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx)
max_lse = tl.maximum(p_lse, s_lse)
p_lse = p_lse - max_lse
s_lse = s_lse - max_lse
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
head_mask = head_arange < HEAD_SIZE
p_out = tl.load(prefix_output + token_idx * num_heads * HEAD_SIZE +
head_idx * HEAD_SIZE + head_arange,
mask=head_mask)
s_out = tl.load(suffix_output + token_idx * num_heads * HEAD_SIZE +
head_idx * HEAD_SIZE + head_arange,
mask=head_mask)
# NOTE(woosuk): Be careful with the numerical stability.
# We should compute the scale first, and then multiply it with the output.
# Do not multiply the output with tl.exp(p_lse) or tl.exp(s_lse) directly.
p_scale = tl.exp(p_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
s_scale = tl.exp(s_lse) / (tl.exp(p_lse) + tl.exp(s_lse))
out = p_out * p_scale + s_out * s_scale
tl.store(output + token_idx * num_heads * HEAD_SIZE +
head_idx * HEAD_SIZE + head_arange,
out,
mask=head_mask)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment