Commit f441aca2 authored by zhuwenwen
Browse files

Update the ROCm mqa_logits and paged_mqa_logits calls and gate the FP8 quantization paths on gfx938

parent cc7715fd
...@@ -83,6 +83,7 @@ from vllm import _custom_ops as ops ...@@ -83,6 +83,7 @@ from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
if current_platform.is_rocm(): if current_platform.is_rocm():
import lightop
from lightop import op, gemmopt from lightop import op, gemmopt
else: else:
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
...@@ -601,6 +602,7 @@ def sparse_attn_indexer( ...@@ -601,6 +602,7 @@ def sparse_attn_indexer(
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
ops.indexer_k_quant_and_cache( ops.indexer_k_quant_and_cache(
k, k,
kv_cache, kv_cache,
...@@ -613,6 +615,7 @@ def sparse_attn_indexer( ...@@ -613,6 +615,7 @@ def sparse_attn_indexer(
if has_prefill: if has_prefill:
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
for chunk in prefill_metadata.chunks: for chunk in prefill_metadata.chunks:
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim], k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
device=k.device, device=k.device,
dtype=torch.float8_e4m3fn) dtype=torch.float8_e4m3fn)
...@@ -627,7 +630,7 @@ def sparse_attn_indexer( ...@@ -627,7 +630,7 @@ def sparse_attn_indexer(
chunk.cu_seq_lens, chunk.cu_seq_lens,
chunk.num_reqs, chunk.num_reqs,
) )
if not current_platform.is_rocm():
logits = fp8_mqa_logits( logits = fp8_mqa_logits(
q_fp8[chunk.token_start:chunk.token_end], q_fp8[chunk.token_start:chunk.token_end],
(k_fp8, k_scale), (k_fp8, k_scale),
...@@ -637,12 +640,18 @@ def sparse_attn_indexer( ...@@ -637,12 +640,18 @@ def sparse_attn_indexer(
) )
else: else:
logits = op.mqa_logits( logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end].half(), q_fp8[chunk.token_start:chunk.token_end],
(k_fp8.half(), k_scale), k,
weights[chunk.token_start:chunk.token_end], weights[chunk.token_start:chunk.token_end] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[chunk.token_start:chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
q_fp8[chunk.token_start:chunk.token_end].shape[0],
k.shape[0],
64,
128,
True,
) )
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
dim=-1)[1] dim=-1)[1]
topk_indices -= chunk.cu_seqlen_ks[:, None] topk_indices -= chunk.cu_seqlen_ks[:, None]
...@@ -692,14 +701,15 @@ def sparse_attn_indexer( ...@@ -692,14 +701,15 @@ def sparse_attn_indexer(
) )
else: else:
logits = gemmopt.paged_mqa_logits( logits = gemmopt.paged_mqa_logits(
padded_q_fp8_decode_tokens.half(), padded_q_fp8_decode_tokens,
kv_cache.half(), kv_cache if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else kv_cache.to(torch.bfloat16),
weights[:num_padded_tokens], weights[:num_padded_tokens] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[:num_padded_tokens].to(torch.float32),
decode_metadata.seq_lens, decode_metadata.seq_lens,
decode_metadata.block_table, decode_metadata.block_table,
decode_metadata.schedule_metadata, decode_metadata.schedule_metadata,
max_context_len=max_model_len, max_model_len,
) )
# padded query len # padded query len
current_device = padded_q_fp8_decode_tokens.device current_device = padded_q_fp8_decode_tokens.device
padded_num_tokens = batch_size * next_n padded_num_tokens = batch_size * next_n
...@@ -753,6 +763,7 @@ def sparse_attn_indexer_fake( ...@@ -753,6 +763,7 @@ def sparse_attn_indexer_fake(
# profile run # profile run
# NOTE(Chen): create the max possible flattened_kv. So that # NOTE(Chen): create the max possible flattened_kv. So that
# profile_run can get correct memory usage. # profile_run can get correct memory usage.
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
_flattened_kv = torch.empty([total_seq_lens, head_dim + 4], _flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
device=k.device, device=k.device,
dtype=torch.uint8) dtype=torch.uint8)
...@@ -845,6 +856,7 @@ class Indexer(nn.Module): ...@@ -845,6 +856,7 @@ class Indexer(nn.Module):
k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion # we only quant q here since k quant is fused with cache insertion
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
q = q.view(-1, self.head_dim) q = q.view(-1, self.head_dim)
q_fp8, q_scale = per_token_group_quant_fp8(q, q_fp8, q_scale = per_token_group_quant_fp8(q,
self.quant_block_size, self.quant_block_size,
...@@ -855,10 +867,12 @@ class Indexer(nn.Module): ...@@ -855,10 +867,12 @@ class Indexer(nn.Module):
q_scale = q_scale.view(-1, self.n_head, 1) q_scale = q_scale.view(-1, self.n_head, 1)
weights, _ = self.weights_proj(hidden_states) weights, _ = self.weights_proj(hidden_states)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
weights = weights.unsqueeze( weights = weights.unsqueeze(
-1) * q_scale * self.softmax_scale * self.n_head**-0.5 -1) * q_scale * self.softmax_scale * self.n_head**-0.5
weights = weights.squeeze(-1) weights = weights.squeeze(-1)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return torch.ops.vllm.sparse_attn_indexer( return torch.ops.vllm.sparse_attn_indexer(
hidden_states, hidden_states,
self.k_cache.prefix, self.k_cache.prefix,
...@@ -874,6 +888,22 @@ class Indexer(nn.Module): ...@@ -874,6 +888,22 @@ class Indexer(nn.Module):
self.max_total_seq_len, self.max_total_seq_len,
self.topk_indices_buffer, self.topk_indices_buffer,
) )
else:
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
class DeepseekV2MLAAttention(nn.Module): class DeepseekV2MLAAttention(nn.Module):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment