You need to sign in or sign up before continuing.
Unverified Commit bee6291d authored by Atream's avatar Atream Committed by GitHub
Browse files

Merge pull request #1220 from kvcache-ai/fix-hopper-flashinfer

fix-hopper-flashinfer
parents b703cc9c b0318fc0
...@@ -50,7 +50,8 @@ class KDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): ...@@ -50,7 +50,8 @@ class KDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
self.workspace_buffer, use_cuda_graph=use_cuda_graph, self.workspace_buffer, use_cuda_graph=use_cuda_graph,
qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf, qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
backend = "fa2",
) )
def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"): def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
......
...@@ -54,7 +54,8 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): ...@@ -54,7 +54,8 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
self.workspace_buffer, use_cuda_graph=use_cuda_graph, self.workspace_buffer, use_cuda_graph=use_cuda_graph,
qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf, qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf, kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
bsz_tensor=self.bsz_tensor_buf bsz_tensor=self.bsz_tensor_buf,
backend = "fa2",
) )
def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"): def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
......
...@@ -100,7 +100,8 @@ class MLAWrapper(): ...@@ -100,7 +100,8 @@ class MLAWrapper():
kv_indptr=self.kv_indptr_buf, kv_indptr=self.kv_indptr_buf,
kv_indices=self.kv_indices_buf, kv_indices=self.kv_indices_buf,
kv_len_arr=self.kv_len_arr_buf, kv_len_arr=self.kv_len_arr_buf,
bsz_tensor=self.batch_size_tensor_buf bsz_tensor=self.batch_size_tensor_buf,
backend = "fa2",
) )
self.need_plan = True self.need_plan = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment