Commit b0318fc0 authored by Atream

fix-hopper-flashinfer

parent 38333cf1
@@ -50,7 +50,8 @@ class KDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
         self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
             self.workspace_buffer, use_cuda_graph=use_cuda_graph,
             qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
-            kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf
+            kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
+            backend = "fa2",
         )
     def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
@@ -54,7 +54,8 @@ class KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
             self.workspace_buffer, use_cuda_graph=use_cuda_graph,
             qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
             kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
-            bsz_tensor=self.bsz_tensor_buf
+            bsz_tensor=self.bsz_tensor_buf,
+            backend = "fa2",
         )
     def batch_embeddings(self, batch: ForwardBatchInput, device="cuda:0"):
@@ -100,7 +100,8 @@ class MLAWrapper():
             kv_indptr=self.kv_indptr_buf,
             kv_indices=self.kv_indices_buf,
             kv_len_arr=self.kv_len_arr_buf,
-            bsz_tensor=self.batch_size_tensor_buf
+            bsz_tensor=self.batch_size_tensor_buf,
+            backend = "fa2",
         )
         self.need_plan = True
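All three hunks make the same change: backend = "fa2" is passed when flashinfer.mla.BatchMLAPagedAttentionWrapper is constructed, pinning the MLA kernel to the FA2 backend rather than leaving the choice to FlashInfer, which, per the commit title, misbehaves on Hopper GPUs. Below is a minimal sketch of the resulting call against upstream FlashInfer and PyTorch; the buffer capacities are illustrative assumptions, and the bsz_tensor argument seen in two hunks is omitted since it appears to come from the project's patched FlashInfer build rather than the upstream signature.

import torch
import flashinfer

# Illustrative capacities; the real values come from the model/runtime config.
max_batch_size = 64
max_pages = 4096
device = "cuda:0"

# FlashInfer takes a byte workspace buffer, and with use_cuda_graph=True the
# int32 index buffers must be preallocated so CUDA-graph replay can reuse them.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
qo_indptr = torch.empty(max_batch_size + 1, dtype=torch.int32, device=device)
kv_indptr = torch.empty(max_batch_size + 1, dtype=torch.int32, device=device)
kv_indices = torch.empty(max_pages, dtype=torch.int32, device=device)
kv_len_arr = torch.empty(max_batch_size, dtype=torch.int32, device=device)

wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
    workspace_buffer,
    use_cuda_graph=True,
    qo_indptr=qo_indptr,
    kv_indptr=kv_indptr,
    kv_indices=kv_indices,
    kv_len_arr=kv_len_arr,
    backend="fa2",  # pin FA2; automatic backend selection on Hopper is what this commit works around
)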