Unverified Commit 2ed96c7a authored by lukec, committed by GitHub

fix flashmla bug (#5272)

parent 2aa3f5e2
@@ -68,9 +68,6 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         self.num_q_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.num_local_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -111,8 +108,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         )
         mla_metadata, num_splits = get_mla_metadata(
             forward_batch.seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.forward_metadata = FlashMLADecodeMetadata(
             mla_metadata,
@@ -141,8 +138,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
             torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_kv_indices = cuda_graph_kv_indices
@@ -171,8 +168,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         )
         mla_metadata, num_splits = get_mla_metadata(
             seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_mla_metadata.copy_(mla_metadata)
         self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -221,8 +218,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         )
         mla_metadata, num_splits = get_mla_metadata(
             seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
         )
         self.cuda_graph_mla_metadata.copy_(mla_metadata)
         self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
...
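Context for the change: all four call sites now pass the full per-TP query-head count and a KV-head count of 1 to FlashMLA's tile scheduler, which is consistent with MLA storing a single latent KV head rather than dividing by a grouped-query num_kv_heads. Below is a minimal standalone sketch of the corrected call, not taken from this repository; Q_LEN, num_q_heads, and the sequence lengths are placeholder values, and it assumes the flash_mla package's get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) signature.

    # Minimal sketch of the corrected get_mla_metadata call (placeholder values,
    # assuming the flash_mla package's public API).
    import torch
    from flash_mla import get_mla_metadata

    Q_LEN = 1            # decode step: one query token per request
    num_q_heads = 128    # query heads after the tensor-parallel split (placeholder)

    # Per-request KV cache lengths for a batch of 4 (placeholder values).
    cache_seqlens = torch.tensor([512, 1024, 256, 768], dtype=torch.int32, device="cuda")

    # MLA keeps a single latent KV head, so the scheduler is told that all
    # Q_LEN * num_q_heads query heads share one KV head (third argument = 1),
    # instead of dividing by a per-TP num_kv_heads as the old code did.
    mla_metadata, num_splits = get_mla_metadata(
        cache_seqlens,
        Q_LEN * num_q_heads,
        1,
    )

The returned mla_metadata and num_splits pair is what the backend stores in FlashMLADecodeMetadata and copies into the CUDA-graph buffers in the hunks above.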