Commit 25ec6a34 authored by zhuwenwen
Browse files

update mqa_logits and paged_mqa_logits

parent 8a4a6fd8
...@@ -674,6 +674,7 @@ def sparse_attn_indexer( ...@@ -674,6 +674,7 @@ def sparse_attn_indexer(
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
ops.indexer_k_quant_and_cache( ops.indexer_k_quant_and_cache(
k, k,
kv_cache, kv_cache,
...@@ -694,6 +695,7 @@ def sparse_attn_indexer( ...@@ -694,6 +695,7 @@ def sparse_attn_indexer(
) )
for chunk in prefill_metadata.chunks: for chunk in prefill_metadata.chunks:
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k_fp8 = k_fp8_full[: chunk.total_seq_lens] k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens] k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache( ops.cp_gather_indexer_k_quant_cache(
...@@ -712,10 +714,15 @@ def sparse_attn_indexer( ...@@ -712,10 +714,15 @@ def sparse_attn_indexer(
logits = fp8_mqa_logits_func( logits = fp8_mqa_logits_func(
q_fp8[chunk.token_start : chunk.token_end], q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32)), (k_fp8, k_scale.view(torch.float32)) if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else k_fp8,
weights[chunk.token_start : chunk.token_end], weights[chunk.token_start : chunk.token_end] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[chunk.token_start : chunk.token_end].to(torch.float32),
chunk.cu_seqlen_ks, chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke, chunk.cu_seqlen_ke,
q_fp8[chunk.token_start : chunk.token_end].shape[0],
k.shape[0],
64,
128,
True,
) )
num_rows = logits.shape[0] num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[ topk_indices = topk_indices_buffer[
...@@ -766,11 +773,11 @@ def sparse_attn_indexer( ...@@ -766,11 +773,11 @@ def sparse_attn_indexer(
logits = fp8_paged_mqa_logits_func( logits = fp8_paged_mqa_logits_func(
padded_q_fp8_decode_tokens, padded_q_fp8_decode_tokens,
kv_cache, kv_cache,
weights[:num_padded_tokens], weights[:num_padded_tokens] if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else weights[:num_padded_tokens].to(torch.float32),
decode_metadata.seq_lens, decode_metadata.seq_lens,
decode_metadata.block_table, decode_metadata.block_table,
decode_metadata.schedule_metadata, decode_metadata.schedule_metadata,
max_model_len=max_model_len, max_model_len,
) )
num_rows = logits.shape[0] num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
...@@ -876,8 +883,8 @@ class Indexer(nn.Module): ...@@ -876,8 +883,8 @@ class Indexer(nn.Module):
# where we store value in fp8 and scale in fp32 # where we store value in fp8 and scale in fp32
# per self.quant_block_size element # per self.quant_block_size element
self.k_cache = DeepseekV32IndexerCache( self.k_cache = DeepseekV32IndexerCache(
head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4, head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4 if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else self.head_dim,
dtype=torch.uint8, dtype=torch.uint8 if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else torch.bfloat16,
prefix=f"{prefix}.k_cache", prefix=f"{prefix}.k_cache",
cache_config=cache_config, cache_config=cache_config,
) )
...@@ -907,6 +914,7 @@ class Indexer(nn.Module): ...@@ -907,6 +914,7 @@ class Indexer(nn.Module):
k = torch.cat([k_pe.squeeze((0, 2)), k_nope], dim=-1) k = torch.cat([k_pe.squeeze((0, 2)), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion # we only quant q here since k quant is fused with cache insertion
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
q = q.view(-1, self.head_dim) q = q.view(-1, self.head_dim)
q_fp8, q_scale = per_token_group_quant_fp8( q_fp8, q_scale = per_token_group_quant_fp8(
q, q,
...@@ -918,6 +926,7 @@ class Indexer(nn.Module): ...@@ -918,6 +926,7 @@ class Indexer(nn.Module):
q_scale = q_scale.view(-1, self.n_head, 1) q_scale = q_scale.view(-1, self.n_head, 1)
weights, _ = self.weights_proj(hidden_states) weights, _ = self.weights_proj(hidden_states)
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
weights = ( weights = (
weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5 weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
) )
...@@ -927,7 +936,7 @@ class Indexer(nn.Module): ...@@ -927,7 +936,7 @@ class Indexer(nn.Module):
hidden_states, hidden_states,
self.k_cache.prefix, self.k_cache.prefix,
self.k_cache.kv_cache[0], self.k_cache.kv_cache[0],
q_fp8, q_fp8 if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" else q,
k, k,
weights, weights,
self.quant_block_size, self.quant_block_size,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment