Commit e7096898 authored by liuchy5's avatar liuchy5
Browse files

Dsa supported.

parent 1ce0a9a2
...@@ -80,7 +80,12 @@ void indexer_k_quant_and_cache( ...@@ -80,7 +80,12 @@ void indexer_k_quant_and_cache(
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size int64_t quant_block_size, // quantization block size
const std::string& scale_fmt); const std::string& scale_fmt);
// Indexer K cache function (no quantization): writes each row of `k` into the
// slot of `kv_cache` selected by the matching `slot_mapping` entry, converting
// the elements to the dtype of `kv_cache`. Entries with a negative slot are
// skipped. `scale_fmt` currently has no effect on the copy.
void indexer_k_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& scale_fmt);
// Extract function to gather quantized K cache // Extract function to gather quantized K cache
void cp_gather_indexer_k_quant_cache( void cp_gather_indexer_k_quant_cache(
......
...@@ -600,6 +600,52 @@ __global__ void indexer_k_quant_and_cache_kernel( ...@@ -600,6 +600,52 @@ __global__ void indexer_k_quant_and_cache_kernel(
reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale; reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
} }
} }
// Copies unquantized indexer K rows into the paged cache, converting every
// element scalar_t -> float -> cache_t.
//
// Launch layout: blockIdx.x selects the token. The threads of all blocks along
// grid.y are flattened (blockIdx.y, threadIdx.y, threadIdx.x) and each thread
// handles VEC_SIZE consecutive elements of that token's row. Threads whose
// first element lies at or beyond head_dim exit immediately, so any grid.y
// that covers head_dim is valid. No shared memory is used.
//
// Preconditions: `k` rows are densely packed (row stride == head_dim) and
// slot_mapping has one entry per token. A negative slot marks a padded token
// and is skipped.
template <typename scalar_t, typename cache_t>
__global__ void indexer_k_cache_kernel(
    const scalar_t* __restrict__ k,  // [num_tokens, head_dim]
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, cache_stride]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int head_dim,                        // dimension of each head
    const int cache_block_size,                // cache block size
    const int cache_stride  // stride for each token in kv_cache
) {
  constexpr int VEC_SIZE = 4;

  // VEC_SIZE source elements packed so they can be fetched with one aligned
  // load. Sized from scalar_t so it is correct for 2-, 4- and 8-byte inputs
  // (the previous fixed float2 load was only right for 2-byte types).
  struct alignas(sizeof(scalar_t) * VEC_SIZE) SrcVec {
    scalar_t elems[VEC_SIZE];
  };

  const int64_t token_idx = blockIdx.x;
  const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
                                threadIdx.y * blockDim.x + threadIdx.x) *
                               VEC_SIZE;
  if (head_dim_idx >= head_dim) {
    return;
  }

  const int64_t slot_idx = slot_mapping[token_idx];
  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }
  const int64_t block_idx = slot_idx / cache_block_size;
  const int64_t block_offset = slot_idx % cache_block_size;

  const scalar_t* src = k + token_idx * head_dim + head_dim_idx;
  // NOTE(review): the per-token offset inside a block uses head_dim rather
  // than cache_stride. The two agree only when cache_stride == head_dim;
  // confirm against the reader of this cache before changing it.
  cache_t* dst = kv_cache + block_idx * cache_block_size * cache_stride +
                 block_offset * head_dim + head_dim_idx;

  // Number of elements this thread owns; smaller than VEC_SIZE only on the
  // tail of a row whose length is not a multiple of VEC_SIZE.
  const int remaining = static_cast<int>(head_dim - head_dim_idx);
  const int valid = remaining < VEC_SIZE ? remaining : VEC_SIZE;

  // The wide load is safe only when every row start is vector-aligned:
  // the base pointer is aligned and head_dim is a multiple of VEC_SIZE.
  // The condition is uniform across the grid, so it causes no divergence.
  const bool can_vectorize =
      (head_dim % VEC_SIZE == 0) &&
      (reinterpret_cast<uintptr_t>(k) % alignof(SrcVec) == 0);

  SrcVec vec;
  if (can_vectorize) {
    vec = *reinterpret_cast<const SrcVec*>(src);
  } else {
    for (int i = 0; i < valid; ++i) {
      vec.elems[i] = src[i];
    }
  }

#pragma unroll
  for (int i = 0; i < VEC_SIZE; ++i) {
    if (i < valid) {
      // Go through float so every (scalar_t, cache_t) pair uses the same
      // conversion path; cache_t's own float conversion does the rounding.
      const float val = static_cast<float>(vec.elems[i]);
      dst[i] = static_cast<cache_t>(val);
    }
  }
}
template <int BLOCK_Y_SIZE> template <int BLOCK_Y_SIZE>
__global__ void cp_gather_indexer_k_quant_cache_kernel( __global__ void cp_gather_indexer_k_quant_cache_kernel(
...@@ -1504,3 +1550,64 @@ void cp_gather_indexer_k_quant_cache( ...@@ -1504,3 +1550,64 @@ void cp_gather_indexer_k_quant_cache(
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32); CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
} }
} }
// Host entry point: stores unquantized indexer K rows into the paged cache.
//
// k            : [num_tokens, head_dim], contiguous, floating / half / bf16.
// kv_cache     : [num_blocks, block_size, cache_stride], float / half / bf16.
// slot_mapping : [num_tokens] int64; a negative entry skips that token.
// scale_fmt    : accepted for interface compatibility; it does not affect the
//                copy because no scaling is applied here.
//
// The kernel is launched on the current CUDA stream of k's device and is not
// synchronized here. Raises (via TORCH_CHECK) on mismatched devices, wrong
// ranks, non-contiguous k, a too-short slot_mapping, or an unsupported
// kv_cache dtype.
void indexer_k_cache(
    torch::Tensor& k,             // [num_tokens, head_dim]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& scale_fmt) {
  (void)scale_fmt;  // unused: this path performs no quantization/scaling

  TORCH_CHECK(k.dim() == 2, "k must be 2-D [num_tokens, head_dim]");
  TORCH_CHECK(kv_cache.dim() == 3,
              "kv_cache must be 3-D [num_blocks, block_size, cache_stride]");
  TORCH_CHECK(k.device() == kv_cache.device(),
              "k and kv_cache must be on the same device");
  TORCH_CHECK(k.device() == slot_mapping.device(),
              "k and slot_mapping must be on the same device");
  // The kernel indexes k as a dense [num_tokens, head_dim] array.
  TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
  TORCH_CHECK(slot_mapping.numel() >= k.size(0),
              "slot_mapping must have at least num_tokens entries");

  const int num_tokens = static_cast<int>(k.size(0));
  const int head_dim = static_cast<int>(k.size(1));
  const int cache_block_size = static_cast<int>(kv_cache.size(1));
  const int cache_stride = static_cast<int>(kv_cache.size(2));

  // A zero-sized grid is an invalid launch configuration; nothing to copy.
  if (num_tokens == 0 || head_dim == 0) {
    return;
  }

  constexpr int vec_size = 4;
  const dim3 block(32, vec_size);
  // Every thread handles vec_size elements, so one block covers
  // block.x * block.y * vec_size elements of a row. Size grid.y from that
  // instead of launching one block per vec_size elements.
  const int elems_per_block = static_cast<int>(block.x * block.y) * vec_size;
  const dim3 grid(num_tokens,
                  (head_dim + elems_per_block - 1) / elems_per_block);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, k.scalar_type(),
      "indexer_k_cache", ([&] {
        using k_t = scalar_t;
        const auto kv_cache_type = kv_cache.scalar_type();
        if (kv_cache_type == at::ScalarType::Float) {
          vllm::indexer_k_cache_kernel<k_t, float>
              <<<grid, block, 0, stream>>>(
                  k.data_ptr<k_t>(), kv_cache.data_ptr<float>(),
                  slot_mapping.data_ptr<int64_t>(), head_dim,
                  cache_block_size, cache_stride);
        } else if (kv_cache_type == at::ScalarType::Half) {
          vllm::indexer_k_cache_kernel<k_t, at::Half>
              <<<grid, block, 0, stream>>>(
                  k.data_ptr<k_t>(), kv_cache.data_ptr<at::Half>(),
                  slot_mapping.data_ptr<int64_t>(), head_dim,
                  cache_block_size, cache_stride);
        } else if (kv_cache_type == at::ScalarType::BFloat16) {
          vllm::indexer_k_cache_kernel<k_t, at::BFloat16>
              <<<grid, block, 0, stream>>>(
                  k.data_ptr<k_t>(), kv_cache.data_ptr<at::BFloat16>(),
                  slot_mapping.data_ptr<int64_t>(), head_dim,
                  cache_block_size, cache_stride);
        } else {
          TORCH_CHECK(false,
                      "Unsupported kv_cache dtype: ", kv_cache.dtype());
        }
      }));
}
...@@ -778,6 +778,17 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -778,6 +778,17 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, \ TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \ "Unsupported input type of kv cache: ", SRC_DTYPE); \
} \ } \
} else if (KV_DTYPE == "fp8_ds_mla") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \ } else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \ } \
......
...@@ -816,6 +816,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -816,6 +816,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"int quant_block_size, str kv_cache_dtype) -> ()"); "int quant_block_size, str kv_cache_dtype) -> ()");
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA, cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
&indexer_k_quant_and_cache); &indexer_k_quant_and_cache);
cache_ops.def(
"indexer_k_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"str kv_cache_dtype) -> ()");
cache_ops.impl("indexer_k_cache", torch::kCUDA,
&indexer_k_cache);
cache_ops.def( cache_ops.def(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! " "cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
......
...@@ -2873,12 +2873,10 @@ def cp_gather_indexer_k_quant_cache( ...@@ -2873,12 +2873,10 @@ def cp_gather_indexer_k_quant_cache(
) )
def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor, def indexer_k_cache(k: torch.Tensor, kv_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
quant_block_size: int,
kv_cache_dtype: str) -> None: kv_cache_dtype: str) -> None:
torch.ops._C_cache_ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, torch.ops._C_cache_ops.indexer_k_cache(k, kv_cache, slot_mapping,
quant_block_size,
kv_cache_dtype) kv_cache_dtype)
......
...@@ -396,6 +396,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -396,6 +396,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if ( if (
getattr(layer, "_marlin_w16a16_moe_enabled", False) getattr(layer, "_marlin_w16a16_moe_enabled", False)
......
...@@ -295,8 +295,9 @@ class SparseAttnIndexer(CustomOp): ...@@ -295,8 +295,9 @@ class SparseAttnIndexer(CustomOp):
k: torch.Tensor, k: torch.Tensor,
weights: torch.Tensor, weights: torch.Tensor,
): ):
if rocm_aiter_ops.is_enabled(): #if rocm_aiter_ops.is_enabled():
return torch.ops.vllm.rocm_aiter_sparse_attn_indexer( if current_platform.is_rocm():
return rocm_aiter_sparse_attn_indexer(
hidden_states, hidden_states,
self.k_cache.prefix, self.k_cache.prefix,
self.k_cache.kv_cache[0], self.k_cache.kv_cache[0],
......
...@@ -712,7 +712,8 @@ class Indexer(nn.Module): ...@@ -712,7 +712,8 @@ class Indexer(nn.Module):
) )
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1) q_scale = q_scale.view(-1, self.n_head, 1)
else:
q_fp8 = q
weights, _ = self.weights_proj(hidden_states) weights, _ = self.weights_proj(hidden_states)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938": if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
weights = ( weights = (
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
...@@ -261,15 +261,15 @@ class RocmPlatform(Platform): ...@@ -261,15 +261,15 @@ class RocmPlatform(Platform):
kv_cache_dtype = attn_selector_config.kv_cache_dtype kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse: if attn_selector_config.use_sparse:
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"): # if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
raise ValueError( # raise ValueError(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." # "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
) # )
assert block_size == 1, ( # assert block_size == 1, (
"Sparse MLA backend on ROCm only supports block size 1 for now." # "Sparse MLA backend on ROCm only supports block size 1 for now."
) # )
logger.info_once("Using Sparse MLA backend.") logger.info_once("Using Sparse MLA backend.")
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path() return AttentionBackendEnum.FLASHMLA_SPARSE.get_path()
if attn_selector_config.use_mla: if attn_selector_config.use_mla:
# if attn_selector_config.use_sparse: # if attn_selector_config.use_sparse:
......
...@@ -27,6 +27,7 @@ logger = init_logger(__name__) ...@@ -27,6 +27,7 @@ logger = init_logger(__name__)
class DeepseekV32IndexerBackend(AttentionBackend): class DeepseekV32IndexerBackend(AttentionBackend):
exclude_from_block_size_selection = True
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "DEEPSEEK_V32_INDEXER" return "DEEPSEEK_V32_INDEXER"
...@@ -323,7 +324,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -323,7 +324,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item() requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
seq_lens = common_attn_metadata.seq_lens[:num_decodes] seq_lens = common_attn_metadata.seq_lens[:num_decodes]
if is_deep_gemm_supported(): #if is_deep_gemm_supported():
if current_platform.is_rocm(): if current_platform.is_rocm():
self.scheduler_metadata_buffer= gemmopt.get_paged_mqa_logits_metadata( self.scheduler_metadata_buffer= gemmopt.get_paged_mqa_logits_metadata(
seq_lens, self.kv_cache_spec.block_size, self.num_sms seq_lens, self.kv_cache_spec.block_size, self.num_sms
......
...@@ -31,7 +31,10 @@ else: ...@@ -31,7 +31,10 @@ else:
if current_platform.is_rocm(): if current_platform.is_rocm():
import flash_mla.cuda as flash_mla_cuda #from vllm.v1.attention.ops import flashmla
#flash_mla_cuda = flashmla.flash_mla_cuda
from flash_mla.flash_mla_interface import flash_mla_cuda
#import flash_mla.cuda as flash_mla_cuda
_flashmla_C_AVAILABLE = True _flashmla_C_AVAILABLE = True
_flashmla_extension_C_AVAILABLE = True _flashmla_extension_C_AVAILABLE = True
......
...@@ -562,7 +562,6 @@ def rocm_aiter_sparse_attn_indexer( ...@@ -562,7 +562,6 @@ def rocm_aiter_sparse_attn_indexer(
chunk.block_table, chunk.block_table,
chunk.cu_seq_lens, chunk.cu_seq_lens,
) )
logits = rocm_fp8_mqa_logits( logits = rocm_fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end], q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32)), (k_fp8, k_scale.view(torch.float32)),
...@@ -570,6 +569,24 @@ def rocm_aiter_sparse_attn_indexer( ...@@ -570,6 +569,24 @@ def rocm_aiter_sparse_attn_indexer(
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
) )
else:
#k_fp8 = torch.empty(
# [chunk.total_seq_lens, head_dim],
# device=k.device,
# dtype=k.dtype,
#)
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k,
weights[chunk.token_start:chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k.shape[0],
64,
128,
True,
)
num_rows = logits.shape[0] num_rows = logits.shape[0]
assert topk_tokens == 2048, "top_k_per_row assumes size 2048" assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = topk_indices_buffer[ topk_indices = topk_indices_buffer[
......
...@@ -5536,7 +5536,9 @@ class GPUModelRunner( ...@@ -5536,7 +5536,9 @@ class GPUModelRunner(
Raises: Raises:
ValueError: If no valid block size found ValueError: If no valid block size found
""" """
# Backends that set a truthy `exclude_from_block_size_selection` attribute
# (e.g. the indexer backend) are left out of block-size selection; backends
# without the attribute participate.
def _participates_in_block_size_selection(backend: type[AttentionBackend]) -> bool:
return not getattr(backend, "exclude_from_block_size_selection", False)
def block_size_is_supported( def block_size_is_supported(
backends: list[type[AttentionBackend]], block_size: int backends: list[type[AttentionBackend]], block_size: int
) -> bool: ) -> bool:
...@@ -5557,8 +5559,11 @@ class GPUModelRunner( ...@@ -5557,8 +5559,11 @@ class GPUModelRunner(
if not is_supported: if not is_supported:
return False return False
return True return True
all_backends = [group.backend for group in attn_groups]
backends = [group.backend for group in attn_groups] backends = [
b for b in all_backends
if _participates_in_block_size_selection(b)
]
# Case 1: if the block_size of kv cache manager is supported by all backends, # Case 1: if the block_size of kv cache manager is supported by all backends,
# return it directly # return it directly
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment