Commit e7096898 authored by liuchy5's avatar liuchy5
Browse files

Dsa supported.

parent 1ce0a9a2
......@@ -80,7 +80,12 @@ void indexer_k_quant_and_cache(
torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt);
// Indexer K cache function: copies each row of `k` into `kv_cache` at the slot
// given by `slot_mapping`, converting elements to the cache dtype. Rows whose
// slot is negative are skipped.
// NOTE(review): `scale_fmt` is accepted but not used by the implementation
// (the op schema exposes this argument as `kv_cache_dtype`) - confirm intent.
void indexer_k_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& scale_fmt);
// Extract function to gather quantized K cache
void cp_gather_indexer_k_quant_cache(
......
......@@ -600,6 +600,52 @@ __global__ void indexer_k_quant_and_cache_kernel(
reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
}
}
// Writes unquantized indexer K vectors into the paged indexer cache.
//
// Launch layout: blockIdx.x selects the token; blockIdx.y together with the
// 2-D thread index is flattened into a lane id, and every lane handles
// VEC_SIZE consecutive elements of the head dimension. No shared memory is
// used and no synchronization is required.
//
// Each element of `k` is converted to float and then to `cache_t` before it is
// stored. Tokens whose slot is negative (padding) are skipped, as are lanes
// that fall entirely past `head_dim`.
//
// Preconditions: `k` is densely packed with row stride `head_dim`.
template <typename scalar_t, typename cache_t>
__global__ void indexer_k_cache_kernel(
    const scalar_t* __restrict__ k,  // [num_tokens, head_dim]
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, cache_stride]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int head_dim,                        // dimension of each head
    const int cache_block_size,                // cache block size
    const int cache_stride  // stride for each token in kv_cache
) {
  constexpr int VEC_SIZE = 4;
  const int64_t token_idx = blockIdx.x;
  const int64_t lane =
      static_cast<int64_t>(blockIdx.y) * blockDim.y * blockDim.x +
      threadIdx.y * blockDim.x + threadIdx.x;
  const int64_t head_dim_idx = lane * VEC_SIZE;
  if (head_dim_idx >= head_dim) {
    return;
  }

  const int64_t slot_idx = slot_mapping[token_idx];
  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }
  const int64_t block_idx = slot_idx / cache_block_size;
  const int64_t block_offset = slot_idx % cache_block_size;

  // Element-wise loads instead of a fixed 8-byte float2 load: the float2 trick
  // only matches VEC_SIZE elements when sizeof(scalar_t) == 2, and it reads
  // the wrong bytes (and past the local) for float/double inputs.
  const scalar_t* src = k + token_idx * head_dim + head_dim_idx;

  // NOTE(review): the per-token offset inside a cache block uses head_dim, not
  // cache_stride. This matches a layout where cache_stride == head_dim or
  // where rows are packed at the start of each block - confirm with readers.
  const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
                             block_offset * head_dim + head_dim_idx;

#pragma unroll
  for (int i = 0; i < VEC_SIZE; i++) {
    // Tail guard: head_dim need not be a multiple of VEC_SIZE.
    if (head_dim_idx + i >= head_dim) {
      break;
    }
    const float val = static_cast<float>(src[i]);
    if constexpr (std::is_same<cache_t, at::Half>::value ||
                  std::is_same<cache_t, __half>::value) {
      kv_cache[dst_offset + i] = __float2half(val);
    } else if constexpr (std::is_same<cache_t, at::BFloat16>::value ||
                         std::is_same<cache_t, __nv_bfloat16>::value) {
      // `auto` avoids naming the vendor-specific bf16 type returned by the
      // intrinsic.
      auto bf16_val = __float2bfloat16(val);
      kv_cache[dst_offset + i] = *reinterpret_cast<at::BFloat16*>(&bf16_val);
    } else if constexpr (std::is_same<cache_t, float>::value) {
      kv_cache[dst_offset + i] = val;
    } else {
      kv_cache[dst_offset + i] = static_cast<cache_t>(val);
    }
  }
}
template <int BLOCK_Y_SIZE>
__global__ void cp_gather_indexer_k_quant_cache_kernel(
......@@ -1504,3 +1550,64 @@ void cp_gather_indexer_k_quant_cache(
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
}
}
// Host launcher for indexer_k_cache_kernel: stores the rows of `k` into the
// paged `kv_cache` at the slots listed in `slot_mapping`.
//
// k            : [num_tokens, head_dim], floating dtype, contiguous
// kv_cache     : [num_blocks, block_size, cache_stride], float / half / bf16
// slot_mapping : [num_tokens], int64; negative entries are skipped by the
//                kernel
// scale_fmt    : accepted for interface compatibility; not used because this
//                path does not quantize.
//
// Raises (via TORCH_CHECK) on rank/device/layout mismatches or an unsupported
// kv_cache dtype. The kernel is launched on the current CUDA stream of k's
// device.
void indexer_k_cache(
    torch::Tensor& k,             // [num_tokens, head_dim]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& scale_fmt) {
  (void)scale_fmt;

  TORCH_CHECK(k.dim() == 2, "k must be 2-D [num_tokens, head_dim]");
  TORCH_CHECK(kv_cache.dim() == 3,
              "kv_cache must be 3-D [num_blocks, block_size, cache_stride]");
  TORCH_CHECK(k.device() == kv_cache.device(),
              "k and kv_cache must be on the same device");
  TORCH_CHECK(k.device() == slot_mapping.device(),
              "k and slot_mapping must be on the same device");
  // The kernel indexes k as token_idx * head_dim + i, i.e. densely packed rows.
  TORCH_CHECK(k.is_contiguous(), "k must be contiguous");

  const int num_tokens = static_cast<int>(k.size(0));
  const int head_dim = static_cast<int>(k.size(1));
  const int cache_block_size = static_cast<int>(kv_cache.size(1));
  const int cache_stride = static_cast<int>(kv_cache.size(2));

  // The kernel reads slot_mapping[token_idx] for every row of k.
  TORCH_CHECK(slot_mapping.numel() >= num_tokens,
              "slot_mapping must have at least one entry per token");

  // A zero-sized grid dimension is an invalid launch configuration.
  if (num_tokens == 0 || head_dim == 0) {
    return;
  }

  constexpr int vec_size = 4;    // elements per thread; matches the kernel
  constexpr int warp_lanes = 32;
  // A block has warp_lanes * vec_size threads and each thread covers vec_size
  // elements, so size grid.y by the elements one block actually handles.
  constexpr int elems_per_block = warp_lanes * vec_size * vec_size;
  dim3 grid(num_tokens, (head_dim + elems_per_block - 1) / elems_per_block);
  dim3 block(warp_lanes, vec_size);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, k.scalar_type(),
      "indexer_k_cache", ([&] {
        using k_t = scalar_t;
        auto kv_cache_type = kv_cache.scalar_type();
        if (kv_cache_type == at::ScalarType::Float) {
          vllm::indexer_k_cache_kernel<k_t, float>
              <<<grid, block, 0, stream>>>(
                  k.data_ptr<k_t>(), kv_cache.data_ptr<float>(),
                  slot_mapping.data_ptr<int64_t>(), head_dim,
                  cache_block_size, cache_stride);
        } else if (kv_cache_type == at::ScalarType::Half) {
          vllm::indexer_k_cache_kernel<k_t, at::Half>
              <<<grid, block, 0, stream>>>(
                  k.data_ptr<k_t>(), kv_cache.data_ptr<at::Half>(),
                  slot_mapping.data_ptr<int64_t>(), head_dim,
                  cache_block_size, cache_stride);
        } else if (kv_cache_type == at::ScalarType::BFloat16) {
          vllm::indexer_k_cache_kernel<k_t, at::BFloat16>
              <<<grid, block, 0, stream>>>(
                  k.data_ptr<k_t>(), kv_cache.data_ptr<at::BFloat16>(),
                  slot_mapping.data_ptr<int64_t>(), head_dim,
                  cache_block_size, cache_stride);
        } else {
          TORCH_CHECK(false,
                      "Unsupported kv_cache dtype: ", kv_cache.dtype());
        }
      }));
}
......@@ -778,6 +778,17 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_ds_mla") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \
......
......@@ -816,6 +816,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"int quant_block_size, str kv_cache_dtype) -> ()");
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
&indexer_k_quant_and_cache);
// Register the unquantized indexer K cache write; kv_cache is mutated in
// place (marked `Tensor!`).
// NOTE(review): the schema names the last argument `kv_cache_dtype`, while the
// C++ signature of indexer_k_cache calls it `scale_fmt` - confirm which
// meaning is intended.
cache_ops.def(
"indexer_k_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"str kv_cache_dtype) -> ()");
cache_ops.impl("indexer_k_cache", torch::kCUDA,
&indexer_k_cache);
cache_ops.def(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
......
......@@ -2873,12 +2873,10 @@ def cp_gather_indexer_k_quant_cache(
)
def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor,
def indexer_k_cache(k: torch.Tensor, kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
quant_block_size: int,
kv_cache_dtype: str) -> None:
torch.ops._C_cache_ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping,
quant_block_size,
torch.ops._C_cache_ops.indexer_k_cache(k, kv_cache, slot_mapping,
kv_cache_dtype)
......
......@@ -396,6 +396,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if (
getattr(layer, "_marlin_w16a16_moe_enabled", False)
......
......@@ -295,8 +295,9 @@ class SparseAttnIndexer(CustomOp):
k: torch.Tensor,
weights: torch.Tensor,
):
if rocm_aiter_ops.is_enabled():
return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
#if rocm_aiter_ops.is_enabled():
if current_platform.is_rocm():
return rocm_aiter_sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
......
......@@ -712,7 +712,8 @@ class Indexer(nn.Module):
)
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
else:
q_fp8 = q
weights, _ = self.weights_proj(hidden_states)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
weights = (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
......@@ -261,15 +261,15 @@ class RocmPlatform(Platform):
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse:
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
raise ValueError(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
)
assert block_size == 1, (
"Sparse MLA backend on ROCm only supports block size 1 for now."
)
# if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
# raise ValueError(
# "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
# )
# assert block_size == 1, (
# "Sparse MLA backend on ROCm only supports block size 1 for now."
# )
logger.info_once("Using Sparse MLA backend.")
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
return AttentionBackendEnum.FLASHMLA_SPARSE.get_path()
if attn_selector_config.use_mla:
# if attn_selector_config.use_sparse:
......
......@@ -27,6 +27,7 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend):
exclude_from_block_size_selection = True
@staticmethod
def get_name() -> str:
return "DEEPSEEK_V32_INDEXER"
......@@ -323,7 +324,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
if is_deep_gemm_supported():
#if is_deep_gemm_supported():
if current_platform.is_rocm():
self.scheduler_metadata_buffer= gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms
......
......@@ -31,7 +31,10 @@ else:
if current_platform.is_rocm():
import flash_mla.cuda as flash_mla_cuda
#from vllm.v1.attention.ops import flashmla
#flash_mla_cuda = flashmla.flash_mla_cuda
from flash_mla.flash_mla_interface import flash_mla_cuda
#import flash_mla.cuda as flash_mla_cuda
_flashmla_C_AVAILABLE = True
_flashmla_extension_C_AVAILABLE = True
......
......@@ -562,7 +562,6 @@ def rocm_aiter_sparse_attn_indexer(
chunk.block_table,
chunk.cu_seq_lens,
)
logits = rocm_fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32)),
......@@ -570,6 +569,24 @@ def rocm_aiter_sparse_attn_indexer(
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
else:
#k_fp8 = torch.empty(
# [chunk.total_seq_lens, head_dim],
# device=k.device,
# dtype=k.dtype,
#)
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k,
weights[chunk.token_start:chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k.shape[0],
64,
128,
True,
)
num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = topk_indices_buffer[
......
......@@ -5536,7 +5536,9 @@ class GPUModelRunner(
Raises:
ValueError: If no valid block size found
"""
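# Backends that set `exclude_from_block_size_selection` (e.g. the indexer backend) do not take part in block size selection.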
def _participates_in_block_size_selection(backend: type[AttentionBackend]) -> bool:
return not getattr(backend, "exclude_from_block_size_selection", False)
def block_size_is_supported(
backends: list[type[AttentionBackend]], block_size: int
) -> bool:
......@@ -5557,8 +5559,11 @@ class GPUModelRunner(
if not is_supported:
return False
return True
backends = [group.backend for group in attn_groups]
all_backends = [group.backend for group in attn_groups]
backends = [
b for b in all_backends
if _participates_in_block_size_selection(b)
]
# Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment