Unverified Commit 8d463fe3 authored by Trevor Morris's avatar Trevor Morris Committed by GitHub
Browse files

Cutlass MLA decode - fix dtype error (#5868)

parent 26fc32d1
...@@ -268,7 +268,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend): ...@@ -268,7 +268,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim) reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
o = cutlass_mla_decode( o = cutlass_mla_decode(
q_nope_and_q_pe=reshape_q, q_nope_and_q_pe=reshape_q.to(self.q_data_type),
kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim), kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
seq_lens=forward_batch.seq_lens.to(torch.int32), seq_lens=forward_batch.seq_lens.to(torch.int32),
page_table=self.forward_metadata.block_kv_indices, page_table=self.forward_metadata.block_kv_indices,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment