Commit 25ec6a34 authored by zhuwenwen's avatar zhuwenwen
Browse files

update mqa_logits and paged_mqa_logits

parent 8a4a6fd8
......@@ -674,13 +674,14 @@ def sparse_attn_indexer(
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill:
......@@ -694,15 +695,16 @@ def sparse_attn_indexer(
)
for chunk in prefill_metadata.chunks:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
fp8_mqa_logits_func = fp8_mqa_logits
if current_platform.is_rocm():
# from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits
......@@ -712,10 +714,15 @@ def sparse_attn_indexer(
logits = fp8_mqa_logits_func(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32)),
weights[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32)) if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else k_fp8,
weights[chunk.token_start : chunk.token_end] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[chunk.token_start : chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
q_fp8[chunk.token_start : chunk.token_end].shape[0],
k.shape[0],
64,
128,
True,
)
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[
......@@ -766,11 +773,11 @@ def sparse_attn_indexer(
logits = fp8_paged_mqa_logits_func(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
weights[:num_padded_tokens] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[:num_padded_tokens].to(torch.float32),
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
max_model_len,
)
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
......@@ -876,8 +883,8 @@ class Indexer(nn.Module):
# where we store value in fp8 and scale in fp32
# per self.quant_block_size element
self.k_cache = DeepseekV32IndexerCache(
head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4,
dtype=torch.uint8,
head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4 if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else self.head_dim,
dtype=torch.uint8 if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else torch.bfloat16,
prefix=f"{prefix}.k_cache",
cache_config=cache_config,
)
......@@ -907,27 +914,29 @@ class Indexer(nn.Module):
k = torch.cat([k_pe.squeeze((0, 2)), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion
q = q.view(-1, self.head_dim)
q_fp8, q_scale = per_token_group_quant_fp8(
q,
self.quant_block_size,
column_major_scales=False,
use_ue8m0=self.scale_fmt is not None,
)
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
q = q.view(-1, self.head_dim)
q_fp8, q_scale = per_token_group_quant_fp8(
q,
self.quant_block_size,
column_major_scales=False,
use_ue8m0=self.scale_fmt is not None,
)
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
weights, _ = self.weights_proj(hidden_states)
weights = (
weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
)
weights = weights.squeeze(-1)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
weights = (
weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
)
weights = weights.squeeze(-1)
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
q_fp8 if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else q,
k,
weights,
self.quant_block_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment