Commit 5b9ad722 authored by lixh6
Browse files

Fix:GLM-5量化模型mla_attention layout修复&&sparse_attn fp8支持

parent a56e3da7
......@@ -1258,11 +1258,17 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
self.kv_b_proj, out_dtype=act_dtype
).T
assert kv_b_proj_weight.shape == (
expected_shape = (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
), (
f"{kv_b_proj_weight.shape=}, "
)
if kv_b_proj_weight.shape != expected_shape:
if kv_b_proj_weight.T.shape == expected_shape:
kv_b_proj_weight = kv_b_proj_weight.T.contiguous()
else:
raise ValueError(
f"kv_b_proj_weight.shape={kv_b_proj_weight.shape}, "
f"expected={expected_shape}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
......
......@@ -94,7 +94,7 @@ def sparse_attn_indexer(
((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunks:
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
if not current_platform.is_rocm(): # or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
......@@ -113,6 +113,15 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ke,
)
else:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment