Commit 806ca2be authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.15.1-dev' into v0.15.1-dev

parents 55a9f930 3b38e285
......@@ -1263,16 +1263,22 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
self.kv_b_proj, out_dtype=act_dtype
).T
assert kv_b_proj_weight.shape == (
expected_shape = (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}"
)
if kv_b_proj_weight.shape != expected_shape:
if kv_b_proj_weight.T.shape == expected_shape:
kv_b_proj_weight = kv_b_proj_weight.T.contiguous()
else:
raise ValueError(
f"kv_b_proj_weight.shape={kv_b_proj_weight.shape}, "
f"expected={expected_shape}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}"
)
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
......
......@@ -94,7 +94,7 @@ def sparse_attn_indexer(
((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunks:
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
if not current_platform.is_rocm(): # or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
......@@ -112,7 +112,16 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
else:
else:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
logits = op.mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment