Commit 477ac28a authored by Atream

fix-update-flashinfer_wrapper_local_chat

parent 5474be52
@@ -435,7 +435,6 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
                                       q_nope.dtype,
                                       compressed_kv.dtype)
             attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)
             """
             k = (
                 torch.cat([compressed_kv, k_pe], dim=-1)
@@ -465,7 +464,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
             attn_output = self.o_proj(attn_output)
             return attn_output, None, past_key_value
         else:
             if past_key_value is not None:
...
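For orientation, below is a rough shape-flow sketch of the path this hunk touches: `mla_wrapper.run` returns per-head outputs in the compressed `kv_lora_rank` space, which are projected back to `v_head_dim` and flattened before `o_proj`. The dimensions, the random tensors, and the einsum projection standing in for the absorbed output weight are illustrative assumptions, not the module's actual code.

```python
import torch

# Arbitrary small sizes for illustration only.
bsz, q_len, num_heads = 1, 4, 8
kv_lora_rank, v_head_dim, hidden_size = 512, 128, 1024

attn_output = torch.randn(bsz, q_len, num_heads, kv_lora_rank)          # stands in for mla_wrapper.run(...).view(...)
w_uv = torch.randn(num_heads, kv_lora_rank, v_head_dim)                 # stand-in for the absorbed output projection
attn_output = torch.einsum("bqhc,hcd->bqhd", attn_output, w_uv)         # -> [bsz, q_len, num_heads, v_head_dim]
attn_output = attn_output.reshape(bsz, q_len, num_heads * v_head_dim)   # as in the hunk above
o_proj = torch.nn.Linear(num_heads * v_head_dim, hidden_size, bias=False)
attn_output = o_proj(attn_output)                                       # -> [bsz, q_len, hidden_size]
print(attn_output.shape)                                                 # torch.Size([1, 4, 1024])
```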
@@ -122,7 +122,7 @@ class MLAWrapper():
         if kv_indices is None:
             assert self.max_batch_size == 1
             kv_indices = self.kv_indices_buf
         self.wrapper.plan(
             qo_indptr,
             kv_indptr,
@@ -189,7 +189,14 @@ class MLAWrapperSingleton():
     @classmethod
     def reset_buffer(cls):
         for device, wrapper in cls.wrappers.items():
-            wrapper.qo_indptr_buf[1] = 1
+            wrapper.qo_indptr_buf[1] = 1 # assert max_batch_size=1 here.
+
+    @classmethod
+    def update_buffer(cls, max_pages):
+        for device, wrapper in cls.wrappers.items():
+            wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here.
+            wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
+            wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf
 if __name__ == "__main__":
...
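To make the semantics of the new `update_buffer` classmethod concrete, here is a minimal, CPU-only sketch of the same buffer update for one device and a single sequence (`max_batch_size == 1`, as the in-line comments assume). The `SimpleNamespace` objects are stand-ins for a registered `MLAWrapper` and for the underlying flashinfer plan wrapper; only the buffer names mirror the diff.

```python
import torch
from types import SimpleNamespace

def update_buffer_sketch(wrapper, max_pages: int, device: str = "cpu") -> None:
    # Single-sequence case (max_batch_size == 1): the one KV segment spans
    # pages [0, max_pages), so only indptr[1] and the page-index list grow.
    wrapper.kv_indptr_buf[1] = max_pages
    wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
    wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf  # keep the plan wrapper pointing at the new buffer

wrapper = SimpleNamespace(
    kv_indptr_buf=torch.tensor([0, 1], dtype=torch.int32),
    kv_indices_buf=torch.arange(0, 1, dtype=torch.int32),
    wrapper=SimpleNamespace(_kv_indices_buf=None),
)
update_buffer_sketch(wrapper, max_pages=8)
print(wrapper.kv_indptr_buf)   # tensor([0, 8], dtype=torch.int32)
print(wrapper.kv_indices_buf)  # tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)
```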
@@ -293,6 +293,7 @@
     kwargs:
       generate_device: "cuda:0"
       prefill_device: "cuda:0"
+      absorb_for_prefill: False
 # GPU 1: layers 15–29
 - match:
@@ -302,6 +303,7 @@
     kwargs:
       generate_device: "cuda:1"
       prefill_device: "cuda:1"
+      absorb_for_prefill: False
 # GPU 2: layers 30–44
 - match:
@@ -311,6 +313,7 @@
     kwargs:
       generate_device: "cuda:2"
       prefill_device: "cuda:2"
+      absorb_for_prefill: False
 # GPU 3: layers 45–60
 - match:
@@ -320,6 +323,7 @@
     kwargs:
       generate_device: "cuda:3"
       prefill_device: "cuda:3"
+      absorb_for_prefill: False
 # === Overall Model Replacement with Transfer Map ===
...
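The same `absorb_for_prefill: False` kwarg is added to each GPU's attention rule. As a hedged illustration (not the ktransformers rule loader), the snippet below only parses one such rule and reads the kwargs that would be handed to the injected module's constructor; the match pattern is a placeholder.

```python
import yaml

rule_text = """
- match:
    name: "^model.layers.0.self_attn$"   # placeholder pattern, not the real rule
  replace:
    class: KDeepseekV2Attention          # module named in the diff above
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
      absorb_for_prefill: False
"""

rules = yaml.safe_load(rule_text)
kwargs = rules[0]["replace"]["kwargs"]
assert kwargs["absorb_for_prefill"] is False   # the flag arrives as a plain boolean kwarg
print(kwargs)
```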
@@ -177,6 +177,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
     else:
         inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
     if use_flashinfer_mla:
+        MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
         MLAWrapperSingleton.need_plan_all()
     logits = model(
...
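Finally, a hedged sketch of the call ordering this hunk introduces in `prefill_and_generate`: the singleton's page-index buffers are grown to the cache's current `max_pages` before every wrapper is asked to re-plan. The helper name is hypothetical, and `singleton` / `cache` are stand-ins for `MLAWrapperSingleton` and the paged KV cache object.

```python
def plan_flashinfer_decode(singleton, cache, use_flashinfer_mla: bool) -> None:
    # Hypothetical helper mirroring the ordering above: grow the singleton's
    # kv_indptr / kv_indices buffers to the cache's current page count first,
    # then mark every per-device wrapper as needing a fresh plan().
    if use_flashinfer_mla:
        singleton.update_buffer(cache.max_pages)
        singleton.need_plan_all()
```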